<!DOCTYPE html>
<html lang="en">
<head>
    <meta http-equiv="content-type" content="text/html;charset=utf-8"/>
    <meta name="viewport" content="width=device-width, initial-scale=1.0"/>
    <meta name="description" content="An implementation of a transformer decoder on a small text dataset in JAX from scratch, with implementations of basic layers like layer normalization and the Adam optimizer."/>

    <meta name="twitter:card" content="summary"/>
    <meta name="twitter:image:src" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
    <meta name="twitter:title" content="Autoregressive Transformer Decoder in JAX from scratch"/>
    <meta name="twitter:description" content="An implementation of a transformer decoder on a small text dataset in JAX from scratch, with implementations of basic layers like layer normalization and the Adam optimizer."/>
    <meta name="twitter:site" content="@labmlai"/>
    <meta name="twitter:creator" content="@labmlai"/>

    <meta property="og:url" content="https://nn.labml.ai/transformers/jax_transformer/index.html"/>
    <meta property="og:title" content="Autoregressive Transformer Decoder in JAX from scratch"/>
    <meta property="og:image" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
    <meta property="og:site_name" content="Autoregressive Transformer Decoder in JAX from scratch"/>
    <meta property="og:type" content="object"/>
    <meta property="og:description" content="An implementation of a transformer decoder on a small text dataset in JAX from scratch, with implementations of basic layers like layer normalization and the Adam optimizer."/>

    <title>Autoregressive Transformer Decoder in JAX from scratch</title>
    <link rel="shortcut icon" href="/icon.png"/>
    <link rel="stylesheet" href="../../pylit.css?v=1">
    <link rel="canonical" href="https://nn.labml.ai/transformers/jax_transformer/index.html"/>
    <link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/katex@0.13.18/dist/katex.min.css" integrity="sha384-zTROYFVGOfTw7JV7KUu8udsvW2fx4lWOsCEDqhBreBwlHI4ioVRtmIvEThzJHGET" crossorigin="anonymous">

    <!-- Global site tag (gtag.js) - Google Analytics -->
    <script async src="https://www.googletagmanager.com/gtag/js?id=G-4V3HC8HBLH"></script>
    <script>
        // Reuse the global data layer if one already exists; otherwise start a new empty array.
        window.dataLayer = window.dataLayer || [];

        // gtag() queues its raw `arguments` object onto `dataLayer`. Per Google's gtag.js setup docs,
        // the gtag.js library loaded asynchronously above reads this queue. Keep this a regular
        // function (not an arrow function): arrow functions have no `arguments` binding.
        function gtag() {
            dataLayer.push(arguments);
        }

        // Standard gtag.js snippet: record the page-load time...
        gtag('js', new Date());

        // ...then configure the Google Analytics measurement ID used by this site.
        gtag('config', 'G-4V3HC8HBLH');
    </script>
</head>
<body>
<div id='container'>
    <div id="background"></div>
    <div class='section'>
        <div class='docs'>
            <p>
                <a class="parent" href="/">home</a>
                <a class="parent" href="../index.html">transformers</a>
                <a class="parent" href="index.html">jax_transformer</a>
            </p>
            <p>
                <a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations" target="_blank">
                    <img alt="Github"
                         src="https://img.shields.io/github/stars/labmlai/annotated_deep_learning_paper_implementations?style=social"
                         style="max-width:100%;"/></a>
                <a href="https://twitter.com/labmlai" rel="nofollow" target="_blank">
                    <img alt="Twitter"
                         src="https://img.shields.io/twitter/follow/labmlai?style=social"
                         style="max-width:100%;"/></a>
            </p>
            <p>
                <a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/transformers/jax_transformer/__init__.py" target="_blank">
                    View code on Github</a>
            </p>
        </div>
    </div>
    <div class='section' id='section-0'>
        <div class='docs doc-strings'>
            <div class='section-link'>
                <a href='#section-0'>#</a>
            </div>
            <h1>Autoregressive Transformer Decoder in JAX from scratch</h1>
<h3>Contents</h3>
<ul><li><a href="#Module">Module class to help us write the layers</a> </li>
<li><a href="#Embedding">Embedding layer</a> </li>
<li><a href="#PositionalEmbedding">Positional embeddings</a> </li>
<li><a href="#Linear">Linear layer</a> </li>
<li><a href="#LayerNormalization">Layer Normalization</a> </li>
<li><a href="#MHA">Multi-head attention</a> </li>
<li><a href="#FFN">Position-wise Feed-Forward layer</a> </li>
<li><a href="#TransformerLayer">Transformer layer</a> </li>
<li><a href="#CrossEntropyLoss">Cross Entropy Loss</a> </li>
<li><a href="#AutoregressiveTransformer">Autoregressive Transformer</a> </li>
<li><a href="#Adam">Adam Optimizer</a> </li>
<li><a href="#Dataset">Simple dataset</a> </li>
<li><a href="#Experiment">Experiment code</a></li></ul>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">28</span><span></span><span class="kn">from</span> <span class="nn">functools</span> <span class="kn">import</span> <span class="n">partial</span>
<span class="lineno">29</span><span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Dict</span><span class="p">,</span> <span class="n">NamedTuple</span><span class="p">,</span> <span class="n">Tuple</span><span class="p">,</span> <span class="n">Any</span><span class="p">,</span> <span class="n">Callable</span>
<span class="lineno">30</span><span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">List</span><span class="p">,</span> <span class="n">TypeVar</span><span class="p">,</span> <span class="n">Generic</span>
<span class="lineno">31</span><span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Union</span><span class="p">,</span> <span class="n">Optional</span>
<span class="lineno">32</span>
<span class="lineno">33</span><span class="kn">import</span> <span class="nn">jax</span>
<span class="lineno">34</span><span class="kn">import</span> <span class="nn">jax.numpy</span> <span class="k">as</span> <span class="nn">jnp</span>
<span class="lineno">35</span><span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
<span class="lineno">36</span>
<span class="lineno">37</span><span class="kn">from</span> <span class="nn">labml</span> <span class="kn">import</span> <span class="n">lab</span><span class="p">,</span> <span class="n">monit</span><span class="p">,</span> <span class="n">experiment</span><span class="p">,</span> <span class="n">tracker</span>
<span class="lineno">38</span><span class="kn">from</span> <span class="nn">labml</span> <span class="kn">import</span> <span class="n">logger</span>
<span class="lineno">39</span><span class="kn">from</span> <span class="nn">labml.logger</span> <span class="kn">import</span> <span class="n">Text</span>
<span class="lineno">40</span><span class="kn">from</span> <span class="nn">labml.utils.download</span> <span class="kn">import</span> <span class="n">download_file</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-1'>
        <div class='docs doc-strings'>
            <div class='section-link'>
                <a href='#section-1'>#</a>
            </div>
            <p> <a id="Module"></a></p>
<h2>Module</h2>
<p>This is a base class for all modules. It handles parameters and transforms methods to pure functions for JAX to compile and differentiate.</p>
<p>You can skip these modules to get into the models directly.</p>
<p>The module stores parameters and sub-modules separately. When we want to transform any method to a pure function, we pass the parameters of the module and its sub-modules as an argument, and assign the passed values to the module object before calling the method.</p>
<p>This is based on the blog post <a href="https://sjmielke.com/jax-purify.htm">From PyTorch to JAX: towards neural net frameworks that purify stateful code</a>.</p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">43</span><span class="k">class</span> <span class="nc">Module</span><span class="p">:</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-2'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-2'>#</a>
            </div>
            <p>Store all parameters and sub-modules in dictionaries </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">63</span>    <span class="n">_submodules</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="s1">&#39;Module&#39;</span><span class="p">]</span>
<span class="lineno">64</span>    <span class="n">_params</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">jnp</span><span class="o">.</span><span class="n">ndarray</span><span class="p">]</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-3'>
        <div class='docs doc-strings'>
            <div class='section-link'>
                <a href='#section-3'>#</a>
            </div>
            <p>Initialize </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">66</span>    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-4'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-4'>#</a>
            </div>
            
        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">68</span>        <span class="bp">self</span><span class="o">.</span><span class="n">_params</span> <span class="o">=</span> <span class="p">{}</span>
<span class="lineno">69</span>        <span class="bp">self</span><span class="o">.</span><span class="n">_submodules</span> <span class="o">=</span> <span class="p">{}</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-5'>
        <div class='docs doc-strings'>
            <div class='section-link'>
                <a href='#section-5'>#</a>
            </div>
            <h3>Get attribute</h3>
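<p>We override the get attribute operation. Python calls this function when you reference an attribute with <code  class="highlight"><span></span><span class="n">model</span><span class="o">.</span><span class="n">attribute</span></code>
 and the normal attribute lookup fails to find it. Parameters and sub-modules are kept in the <code  class="highlight"><span></span><span class="n">_params</span></code>
 and <code  class="highlight"><span></span><span class="n">_submodules</span></code>
 dictionaries instead of the instance <code  class="highlight"><span></span><span class="vm">__dict__</span></code>
, so the normal lookup misses them and they are resolved here.</p>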
<p><a href="https://rszalski.github.io/magicmethods/">Read this guide</a> if you are not familiar with Python magic methods.</p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">71</span>    <span class="k">def</span> <span class="fm">__getattr__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">attr_name</span><span class="p">:</span> <span class="nb">str</span><span class="p">):</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-6'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-6'>#</a>
            </div>
            <p>If the attribute is a parameter </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">83</span>        <span class="k">if</span> <span class="n">attr_name</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">_params</span><span class="p">:</span>
<span class="lineno">84</span>            <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_params</span><span class="p">[</span><span class="n">attr_name</span><span class="p">]</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-7'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-7'>#</a>
            </div>
            <p>If the attribute is a sub-module </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">86</span>        <span class="k">elif</span> <span class="n">attr_name</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">_submodules</span><span class="p">:</span>
<span class="lineno">87</span>            <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_submodules</span><span class="p">[</span><span class="n">attr_name</span><span class="p">]</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-8'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-8'>#</a>
            </div>
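<p>Otherwise fall back to the instance <code  class="highlight"><span></span><span class="vm">__dict__</span></code>
, where Python stores ordinary attributes. Because <code  class="highlight"><span></span><span class="fm">__getattr__</span></code>
 only runs after the normal lookup has failed, a name that reaches this branch is normally missing from <code  class="highlight"><span></span><span class="vm">__dict__</span></code>
 too. The lookup then raises a <code  class="highlight"><span></span><span class="ne">KeyError</span></code>
 rather than an <code  class="highlight"><span></span><span class="ne">AttributeError</span></code>
. </p>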

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">90</span>        <span class="k">else</span><span class="p">:</span>
<span class="lineno">91</span>            <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="vm">__dict__</span><span class="p">[</span><span class="n">attr_name</span><span class="p">]</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-9'>
        <div class='docs doc-strings'>
            <div class='section-link'>
                <a href='#section-9'>#</a>
            </div>
            <h3>Set attribute</h3>
<p>We override the set attribute operation. So when you assign to an attribute with <code  class="highlight"><span></span><span class="n">model</span><span class="o">.</span><span class="n">attribute</span> <span class="o">=</span> <span class="n">value</span></code>
 this function gets called.</p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">93</span>    <span class="k">def</span> <span class="fm">__setattr__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">key</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="n">Any</span><span class="p">):</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-10'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-10'>#</a>
            </div>
            <p>If the value is also a module </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">102</span>        <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">value</span><span class="p">,</span> <span class="n">Module</span><span class="p">):</span>
<span class="lineno">103</span>            <span class="bp">self</span><span class="o">.</span><span class="n">_submodules</span><span class="p">[</span><span class="n">key</span><span class="p">]</span> <span class="o">=</span> <span class="n">value</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-11'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-11'>#</a>
            </div>
            <p>If the value is a JAX array </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">105</span>        <span class="k">elif</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">value</span><span class="p">,</span> <span class="n">jnp</span><span class="o">.</span><span class="n">ndarray</span><span class="p">):</span>
<span class="lineno">106</span>            <span class="bp">self</span><span class="o">.</span><span class="n">_params</span><span class="p">[</span><span class="n">key</span><span class="p">]</span> <span class="o">=</span> <span class="n">value</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-12'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-12'>#</a>
            </div>
            <p>Otherwise add it to <code  class="highlight"><span></span><span class="vm">__dict__</span></code>
 </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">108</span>        <span class="k">else</span><span class="p">:</span>
<span class="lineno">109</span>            <span class="bp">self</span><span class="o">.</span><span class="vm">__dict__</span><span class="p">[</span><span class="n">key</span><span class="p">]</span> <span class="o">=</span> <span class="n">value</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-13'>
        <div class='docs doc-strings'>
            <div class='section-link'>
                <a href='#section-13'>#</a>
            </div>
            <h3>Clear parameters</h3>
<p>This clears out all the parameters of the module and, recursively, of its sub-modules. It is used when a method is called as a pure function: we first clear out all the parameters and then assign the parameters passed to the pure function.</p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">111</span>    <span class="k">def</span> <span class="nf">_clear_params</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-14'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-14'>#</a>
            </div>
            <p>Clear parameters of the module </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">120</span>        <span class="bp">self</span><span class="o">.</span><span class="n">_params</span> <span class="o">=</span> <span class="p">{}</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-15'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-15'>#</a>
            </div>
            <p>Recursively clear parameters of submodules </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">122</span>        <span class="k">for</span> <span class="n">sm</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">_submodules</span><span class="o">.</span><span class="n">values</span><span class="p">():</span>
<span class="lineno">123</span>            <span class="n">sm</span><span class="o">.</span><span class="n">_clear_params</span><span class="p">()</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-16'>
        <div class='docs doc-strings'>
            <div class='section-link'>
                <a href='#section-16'>#</a>
            </div>
            <h3>Collect all the parameters</h3>
<p>This recursively collects all the parameters of the module and sub-modules into a dictionary.</p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">125</span>    <span class="k">def</span> <span class="nf">get_params</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">jnp</span><span class="o">.</span><span class="n">ndarray</span><span class="p">]:</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-17'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-17'>#</a>
            </div>
            <p>Parameters of the model </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">133</span>        <span class="n">params</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_params</span><span class="o">.</span><span class="n">copy</span><span class="p">()</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-18'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-18'>#</a>
            </div>
            <p>Parameters of the submodules </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">135</span>        <span class="k">for</span> <span class="n">sm_name</span><span class="p">,</span> <span class="n">sm</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">_submodules</span><span class="o">.</span><span class="n">items</span><span class="p">():</span>
<span class="lineno">136</span>            <span class="k">for</span> <span class="n">name</span><span class="p">,</span> <span class="n">value</span> <span class="ow">in</span> <span class="n">sm</span><span class="o">.</span><span class="n">get_params</span><span class="p">()</span><span class="o">.</span><span class="n">items</span><span class="p">():</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-19'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-19'>#</a>
            </div>
            <p>The dictionary keys are of the form <code  class="highlight"><span></span><span class="n">module_name</span><span class="o">/</span><span class="n">module_name</span><span class="o">/</span><span class="n">param_name</span></code>
 </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">138</span>                <span class="n">params</span><span class="p">[</span><span class="n">sm_name</span> <span class="o">+</span> <span class="s2">&quot;/&quot;</span> <span class="o">+</span> <span class="n">name</span><span class="p">]</span> <span class="o">=</span> <span class="n">value</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-20'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-20'>#</a>
            </div>
            <p> </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">140</span>        <span class="k">return</span> <span class="n">params</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-21'>
        <div class='docs doc-strings'>
            <div class='section-link'>
                <a href='#section-21'>#</a>
            </div>
            <h3>Set all the parameters</h3>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">142</span>    <span class="k">def</span> <span class="nf">_set_params</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">params</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">jnp</span><span class="o">.</span><span class="n">ndarray</span><span class="p">]):</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-22'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-22'>#</a>
            </div>
            <p>Iterate through parameters. Their names have the form <code  class="highlight"><span></span><span class="n">module_name</span><span class="o">/</span><span class="n">module_name</span><span class="o">/</span><span class="n">param_name</span></code>
 </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">149</span>        <span class="k">for</span> <span class="n">name</span><span class="p">,</span> <span class="n">value</span> <span class="ow">in</span> <span class="n">params</span><span class="o">.</span><span class="n">items</span><span class="p">():</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-23'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-23'>#</a>
            </div>
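            <p>Split the name on <code  class="highlight"><span></span><span class="s2">&quot;/&quot;</span></code>
 to get the module names and the parameter name </p>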

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">151</span>            <span class="bp">self</span><span class="o">.</span><span class="n">_set_param</span><span class="p">(</span><span class="n">name</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="s2">&quot;/&quot;</span><span class="p">),</span> <span class="n">value</span><span class="p">)</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-24'>
        <div class='docs doc-strings'>
            <div class='section-link'>
                <a href='#section-24'>#</a>
            </div>
            <h3>Set a single parameter</h3>
<p>This is called by <code  class="highlight"><span></span><span class="n">_set_params</span></code>
</p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">153</span>    <span class="k">def</span> <span class="nf">_set_param</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">param_path</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">],</span> <span class="n">value</span><span class="p">:</span> <span class="n">jnp</span><span class="o">.</span><span class="n">ndarray</span><span class="p">):</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-25'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-25'>#</a>
            </div>
            <p>No module names, i.e. a parameter of this module </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">160</span>        <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">param_path</span><span class="p">)</span> <span class="o">==</span> <span class="mi">1</span><span class="p">:</span>
<span class="lineno">161</span>            <span class="bp">self</span><span class="o">.</span><span class="n">_params</span><span class="p">[</span><span class="n">param_path</span><span class="p">[</span><span class="mi">0</span><span class="p">]]</span> <span class="o">=</span> <span class="n">value</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-26'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-26'>#</a>
            </div>
            <p>Parameter of a submodule </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">163</span>        <span class="k">else</span><span class="p">:</span>
<span class="lineno">164</span>            <span class="bp">self</span><span class="o">.</span><span class="n">_submodules</span><span class="p">[</span><span class="n">param_path</span><span class="p">[</span><span class="mi">0</span><span class="p">]]</span><span class="o">.</span><span class="n">_set_param</span><span class="p">(</span><span class="n">param_path</span><span class="p">[</span><span class="mi">1</span><span class="p">:],</span> <span class="n">value</span><span class="p">)</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-27'>
        <div class='docs doc-strings'>
            <div class='section-link'>
                <a href='#section-27'>#</a>
            </div>
            <h3>Transform a member method to a pure function</h3>
<p>This transforms a member method to a pure function that accepts a dictionary of parameters as an argument.</p>
<p>For example,</p>
<pre  class="highlight lang-python"><code><span></span><span class="n">params</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">get_params</span><span class="p">()</span>
<span class="n">pure_function</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">purify</span><span class="p">(</span><span class="n">model</span><span class="o">.</span><span class="n">calculate_loss</span><span class="p">)</span>
<span class="n">output</span> <span class="o">=</span> <span class="n">pure_function</span><span class="p">(</span><span class="n">params</span><span class="p">,</span> <span class="n">data</span><span class="p">)</span></code></pre>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">166</span>    <span class="k">def</span> <span class="nf">purify</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">method</span><span class="p">:</span> <span class="n">Callable</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Callable</span><span class="p">:</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-28'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-28'>#</a>
            </div>
            
        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">182</span>        <span class="k">def</span> <span class="nf">pure_method</span><span class="p">(</span><span class="n">params</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">jnp</span><span class="o">.</span><span class="n">array</span><span class="p">],</span> <span class="o">*</span><span class="n">args</span><span class="p">):</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-29'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-29'>#</a>
            </div>
            <p>Clear parameters in the object </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">184</span>            <span class="bp">self</span><span class="o">.</span><span class="n">_clear_params</span><span class="p">()</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-30'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-30'>#</a>
            </div>
            <p>Assign the passed parameters </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">186</span>            <span class="bp">self</span><span class="o">.</span><span class="n">_set_params</span><span class="p">(</span><span class="n">params</span><span class="p">)</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-31'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-31'>#</a>
            </div>
            <p>Invoke the method </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">188</span>            <span class="n">result</span> <span class="o">=</span> <span class="n">method</span><span class="p">(</span><span class="o">*</span><span class="n">args</span><span class="p">)</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-32'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-32'>#</a>
            </div>
            <p>Return the result </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">190</span>            <span class="k">return</span> <span class="n">result</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-33'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-33'>#</a>
            </div>
            <p>Return the pure function </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">193</span>        <span class="k">return</span> <span class="n">pure_method</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-34'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-34'>#</a>
            </div>
            <p>Type for generics in the module list class </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">197</span><span class="n">M</span> <span class="o">=</span> <span class="n">TypeVar</span><span class="p">(</span><span class="s1">&#39;M&#39;</span><span class="p">,</span> <span class="n">bound</span><span class="o">=</span><span class="n">Module</span><span class="p">)</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-35'>
        <div class='docs doc-strings'>
            <div class='section-link'>
                <a href='#section-35'>#</a>
            </div>
            <h2>Module list</h2>
<p>This stores a list of modules. We need this for the transformer decoder to hold its list of transformer layers.</p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">200</span><span class="k">class</span> <span class="nc">ModuleList</span><span class="p">(</span><span class="n">Module</span><span class="p">,</span> <span class="n">Generic</span><span class="p">[</span><span class="n">M</span><span class="p">]):</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-36'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-36'>#</a>
            </div>
            <p>The list of submodules </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">209</span>    <span class="n">_submodules</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">M</span><span class="p">]</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-37'>
        <div class='docs doc-strings'>
            <div class='section-link'>
                <a href='#section-37'>#</a>
            </div>
            <p> Initialize with a list of modules.</p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">211</span>    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">modules</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">M</span><span class="p">]):</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-38'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-38'>#</a>
            </div>
            
        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">215</span>        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
<span class="lineno">216</span>        <span class="bp">self</span><span class="o">.</span><span class="n">_submodules</span> <span class="o">=</span> <span class="n">modules</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-39'>
        <div class='docs doc-strings'>
            <div class='section-link'>
                <a href='#section-39'>#</a>
            </div>
            <h3>Get the <code  class="highlight"><span></span><span class="n">idx</span></code>-th module</h3>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">218</span>    <span class="k">def</span> <span class="fm">__getitem__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">idx</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">M</span><span class="p">:</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-40'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-40'>#</a>
            </div>
            
        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">222</span>        <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_submodules</span><span class="p">[</span><span class="n">idx</span><span class="p">]</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-41'>
        <div class='docs doc-strings'>
            <div class='section-link'>
                <a href='#section-41'>#</a>
            </div>
            <p> Setting items is not supported</p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">224</span>    <span class="k">def</span> <span class="fm">__setitem__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">key</span><span class="p">,</span> <span class="n">value</span><span class="p">):</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-42'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-42'>#</a>
            </div>
            
        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">228</span>        <span class="k">raise</span> <span class="ne">NotImplementedError</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-43'>
        <div class='docs doc-strings'>
            <div class='section-link'>
                <a href='#section-43'>#</a>
            </div>
            <h3>Number of modules</h3>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">230</span>    <span class="k">def</span> <span class="fm">__len__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-44'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-44'>#</a>
            </div>
            
        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">234</span>        <span class="k">return</span> <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_submodules</span><span class="p">)</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-45'>
        <div class='docs doc-strings'>
            <div class='section-link'>
                <a href='#section-45'>#</a>
            </div>
            <p> Override <code  class="highlight"><span></span><span class="fm">__getattr__</span></code>
 of <code  class="highlight"><span></span><span class="n">Module</span></code>
</p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">236</span>    <span class="k">def</span> <span class="fm">__getattr__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">item</span><span class="p">):</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-46'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-46'>#</a>
            </div>
            
        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">240</span>        <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="vm">__dict__</span><span class="p">[</span><span class="n">item</span><span class="p">]</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-47'>
        <div class='docs doc-strings'>
            <div class='section-link'>
                <a href='#section-47'>#</a>
            </div>
            <p> Override <code  class="highlight"><span></span><span class="fm">__setattr__</span></code>
 of <code  class="highlight"><span></span><span class="n">Module</span></code>
</p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">242</span>    <span class="k">def</span> <span class="fm">__setattr__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">key</span><span class="p">,</span> <span class="n">value</span><span class="p">):</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-48'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-48'>#</a>
            </div>
            
        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">246</span>        <span class="bp">self</span><span class="o">.</span><span class="vm">__dict__</span><span class="p">[</span><span class="n">key</span><span class="p">]</span> <span class="o">=</span> <span class="n">value</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-49'>
        <div class='docs doc-strings'>
            <div class='section-link'>
                <a href='#section-49'>#</a>
            </div>
            <h3>Clear all parameters</h3>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">248</span>    <span class="k">def</span> <span class="nf">_clear_params</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-50'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-50'>#</a>
            </div>
            
        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">252</span>        <span class="bp">self</span><span class="o">.</span><span class="n">_params</span> <span class="o">=</span> <span class="p">{}</span>
<span class="lineno">253</span>        <span class="k">for</span> <span class="n">sm</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">_submodules</span><span class="p">:</span>
<span class="lineno">254</span>            <span class="n">sm</span><span class="o">.</span><span class="n">_clear_params</span><span class="p">()</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-51'>
        <div class='docs doc-strings'>
            <div class='section-link'>
                <a href='#section-51'>#</a>
            </div>
            <h3>Get all parameters</h3>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">256</span>    <span class="k">def</span> <span class="nf">get_params</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-52'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-52'>#</a>
            </div>
            
        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">260</span>        <span class="n">params</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_params</span>
<span class="lineno">261</span>        <span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">sm</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_submodules</span><span class="p">):</span>
<span class="lineno">262</span>            <span class="k">for</span> <span class="n">name</span><span class="p">,</span> <span class="n">value</span> <span class="ow">in</span> <span class="n">sm</span><span class="o">.</span><span class="n">get_params</span><span class="p">()</span><span class="o">.</span><span class="n">items</span><span class="p">():</span>
<span class="lineno">263</span>                <span class="n">params</span><span class="p">[</span><span class="sa">f</span><span class="s1">&#39;</span><span class="si">{</span><span class="n">i</span><span class="si">}</span><span class="s1">/</span><span class="si">{</span><span class="n">name</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">value</span>
<span class="lineno">264</span>        <span class="k">return</span> <span class="n">params</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-53'>
        <div class='docs doc-strings'>
            <div class='section-link'>
                <a href='#section-53'>#</a>
            </div>
            <h3>Set a parameter</h3>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">266</span>    <span class="k">def</span> <span class="nf">_set_param</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">param_path</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">],</span> <span class="n">value</span><span class="p">:</span> <span class="n">jnp</span><span class="o">.</span><span class="n">ndarray</span><span class="p">):</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-54'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-54'>#</a>
            </div>
            
        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">270</span>        <span class="bp">self</span><span class="o">.</span><span class="n">_submodules</span><span class="p">[</span><span class="nb">int</span><span class="p">(</span><span class="n">param_path</span><span class="p">[</span><span class="mi">0</span><span class="p">])]</span><span class="o">.</span><span class="n">_set_param</span><span class="p">(</span><span class="n">param_path</span><span class="p">[</span><span class="mi">1</span><span class="p">:],</span> <span class="n">value</span><span class="p">)</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-55'>
        <div class='docs doc-strings'>
            <div class='section-link'>
                <a href='#section-55'>#</a>
            </div>
            <p> <a id="Embedding"></a></p>
<h2>Embedding layer</h2>
<p>This maintains embeddings by id.</p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">273</span><span class="k">class</span> <span class="nc">Embedding</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-56'>
        <div class='docs doc-strings'>
            <div class='section-link'>
                <a href='#section-56'>#</a>
            </div>
            <ul><li><code  class="highlight"><span></span><span class="n">rnd_key</span></code>
 is the PRNG state </li>
<li><code  class="highlight"><span></span><span class="n">n_embeddings</span></code>
 is the number of embeddings </li>
<li><code  class="highlight"><span></span><span class="n">n_dim</span></code>
 is the size of an embedding</li></ul>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">282</span>    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">rnd_key</span><span class="p">:</span> <span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">PRNGKey</span><span class="p">,</span> <span class="n">n_embeddings</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">n_dim</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-57'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-57'>#</a>
            </div>
            
        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">288</span>        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-58'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-58'>#</a>
            </div>
            <p>Embeddings are initialized from <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord mathcal" style="margin-right:0.14736em;">N</span><span class="mopen">(</span><span class="mord coloredeq eqbz" style=""><span class="mord" style="">0</span></span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord">1</span><span class="mclose">)</span></span></span></span></span> </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">290</span>        <span class="bp">self</span><span class="o">.</span><span class="n">embeddings</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">normal</span><span class="p">(</span><span class="n">rnd_key</span><span class="p">,</span> <span class="p">(</span><span class="n">n_embeddings</span><span class="p">,</span> <span class="n">n_dim</span><span class="p">))</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-59'>
        <div class='docs doc-strings'>
            <div class='section-link'>
                <a href='#section-59'>#</a>
            </div>
            <p> Return the embeddings for the given ids</p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">292</span>    <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">ids</span><span class="p">:</span> <span class="n">jnp</span><span class="o">.</span><span class="n">ndarray</span><span class="p">):</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-60'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-60'>#</a>
            </div>
            
        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">296</span>        <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">embeddings</span><span class="p">[</span><span class="n">ids</span><span class="p">,</span> <span class="p">:]</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-61'>
        <div class='docs doc-strings'>
            <div class='section-link'>
                <a href='#section-61'>#</a>
            </div>
            <p> <a id="PositionalEmbedding"></a></p>
<h2>Embed tokens and add parameterized positional encodings</h2>
<p>This is based on <a href="https://nn.labml.ai/transformers/models.html#EmbeddingsWithLearnedPositionalEncoding">our PyTorch implementation</a>.</p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">299</span><span class="k">class</span> <span class="nc">EmbeddingsWithLearnedPositionalEncoding</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-62'>
        <div class='docs doc-strings'>
            <div class='section-link'>
                <a href='#section-62'>#</a>
            </div>
            <ul><li><code  class="highlight"><span></span><span class="n">rnd_key</span></code>
 is the PRNG state </li>
<li><code  class="highlight"><span></span><span class="n">n_vocab</span></code>
 is the vocabulary size </li>
<li><code  class="highlight"><span></span><span class="n">d_model</span></code>
 is the embedding size </li>
<li><code  class="highlight"><span></span><span class="n">max_len</span></code>
 is the maximum sequence length (to initialize positional encodings)</li></ul>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">309</span>    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">rnd_key</span><span class="p">:</span> <span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">PRNGKey</span><span class="p">,</span> <span class="n">n_vocab</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">d_model</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">max_len</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">4096</span><span class="p">):</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-63'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-63'>#</a>
            </div>
            
        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">316</span>        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-64'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-64'>#</a>
            </div>
            <p>Embeddings </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">318</span>        <span class="bp">self</span><span class="o">.</span><span class="n">embeddings</span> <span class="o">=</span> <span class="n">Embedding</span><span class="p">(</span><span class="n">rnd_key</span><span class="p">,</span> <span class="n">n_vocab</span><span class="p">,</span> <span class="n">d_model</span><span class="p">)</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-65'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-65'>#</a>
            </div>
            <p>Scaling coefficient for the token embeddings <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1.383108em;vertical-align:-0.538em;"></span><span class="mord"><span class="mopen nulldelimiter"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.845108em;"><span style="top:-2.5335085em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord sqrt mtight"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.937845em;"><span class="svg-align" style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="mord mtight" style="padding-left:0.833em;"><span class="mord mathnormal mtight">d</span></span></span><span style="top:-2.8978450000000002em;"><span class="pstrut" style="height:3em;"></span><span class="hide-tail mtight" style="min-width:0.853em;height:1.08em;"><svg height="1.08em" preserveaspectratio="xMinYMin slice" viewbox="0 0 400000 1080" width="400em" xmlns="http://www.w3.org/2000/svg"><path d="M95,702
c-2.7,0,-7.17,-2.7,-13.5,-8c-5.8,-5.3,-9.5,-10,-9.5,-14
c0,-2,0.3,-3.3,1,-4c1.3,-2.7,23.83,-20.7,67.5,-54
c44.2,-33.3,65.8,-50.3,66.5,-51c1.3,-1.3,3,-2,5,-2c4.7,0,8.7,3.3,12,10
s173,378,173,378c0.7,0,35.3,-71,104,-213c68.7,-142,137.5,-285,206.5,-429
c69,-144,104.5,-217.7,106.5,-221
l0 -0
c5.3,-9.3,12,-14,20,-14
H400000v40H845.2724
s-225.272,467,-225.272,467s-235,486,-235,486c-2.7,4.7,-9,7,-19,7
c-6,0,-10,-1,-12,-3s-194,-422,-194,-422s-65,47,-65,47z
M834 80h400000v40h-400000z"></path></svg></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.102155em;"><span></span></span></span></span></span></span></span></span><span style="top:-3.23em;"><span class="pstrut" style="height:3em;"></span><span class="frac-line" style="border-bottom-width:0.04em;"></span></span><span style="top:-3.394em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight">1</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.538em;"><span></span></span></span></span></span><span class="mclose nulldelimiter"></span></span></span></span></span></span> </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">320</span>        <span class="bp">self</span><span class="o">.</span><span class="n">pe_coef</span> <span class="o">=</span> <span class="mi">1</span> <span class="o">/</span> <span class="n">d_model</span> <span class="o">**</span> <span class="mf">0.5</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-66'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-66'>#</a>
            </div>
            <p>Positional encodings initialized to zeros </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">322</span>        <span class="bp">self</span><span class="o">.</span><span class="n">positional_encodings</span> <span class="o">=</span> <span class="n">jnp</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">max_len</span><span class="p">,</span> <span class="n">d_model</span><span class="p">))</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-67'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-67'>#</a>
            </div>
            
        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">324</span>    <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">jnp</span><span class="o">.</span><span class="n">ndarray</span><span class="p">):</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-68'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-68'>#</a>
            </div>
            <p>Get positional encodings </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">326</span>        <span class="n">pe</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">positional_encodings</span><span class="p">[:</span><span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]]</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-69'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-69'>#</a>
            </div>
            <p>Get embeddings and add positional encodings </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">328</span>        <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">embeddings</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">pe_coef</span> <span class="o">+</span> <span class="n">pe</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-70'>
        <div class='docs doc-strings'>
            <div class='section-link'>
                <a href='#section-70'>#</a>
            </div>
            <p> <a id="Linear"></a></p>
<h2>Linear Layer</h2>
<p>This is a simple linear layer with a weight matrix and a bias vector.</p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">331</span><span class="k">class</span> <span class="nc">Linear</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-71'>
        <div class='docs doc-strings'>
            <div class='section-link'>
                <a href='#section-71'>#</a>
            </div>
            <ul><li><code  class="highlight"><span></span><span class="n">rnd_key</span></code>
 is the PRNG state </li>
<li><code  class="highlight"><span></span><span class="n">in_features</span></code>
 is the number of features in the input </li>
<li><code  class="highlight"><span></span><span class="n">out_features</span></code>
 is the number of features in the output</li></ul>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">340</span>    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">rnd_key</span><span class="p">:</span> <span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">PRNGKey</span><span class="p">,</span> <span class="n">in_features</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">out_features</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-72'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-72'>#</a>
            </div>
            
        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">346</span>        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-73'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-73'>#</a>
            </div>
            <p>Initialize weights to <span ><span class="katex-display"><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:3.0000299999999998em;vertical-align:-1.25003em;"></span><span class="mord mathcal" style="margin-right:0.09931em;">U</span><span class="mord"><span class="delimsizing size4">(</span></span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin">−</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span></span><span class="base"><span class="strut" style="height:3.0000299999999998em;vertical-align:-1.25003em;"></span><span class="mord"><span class="mopen nulldelimiter"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:1.32144em;"><span style="top:-2.25278em;"><span class="pstrut" style="height:3em;"></span><span class="mord"><span class="mord sqrt"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.85722em;"><span class="svg-align" style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="mord" style="padding-left:0.833em;"><span class="mord"><span class="mord mathnormal">d</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight">in</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span><span style="top:-2.81722em;"><span class="pstrut" style="height:3em;"></span><span class="hide-tail" style="min-width:0.853em;height:1.08em;"><svg height="1.08em" 
preserveaspectratio="xMinYMin slice" viewbox="0 0 400000 1080" width="400em" xmlns="http://www.w3.org/2000/svg"><path d="M95,702
c-2.7,0,-7.17,-2.7,-13.5,-8c-5.8,-5.3,-9.5,-10,-9.5,-14
c0,-2,0.3,-3.3,1,-4c1.3,-2.7,23.83,-20.7,67.5,-54
c44.2,-33.3,65.8,-50.3,66.5,-51c1.3,-1.3,3,-2,5,-2c4.7,0,8.7,3.3,12,10
s173,378,173,378c0.7,0,35.3,-71,104,-213c68.7,-142,137.5,-285,206.5,-429
c69,-144,104.5,-217.7,106.5,-221
l0 -0
c5.3,-9.3,12,-14,20,-14
H400000v40H845.2724
s-225.272,467,-225.272,467s-235,486,-235,486c-2.7,4.7,-9,7,-19,7
c-6,0,-10,-1,-12,-3s-194,-422,-194,-422s-65,47,-65,47z
M834 80h400000v40h-400000z"></path></svg></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.18278000000000005em;"><span></span></span></span></span></span></span></span><span style="top:-3.23em;"><span class="pstrut" style="height:3em;"></span><span class="frac-line" style="border-bottom-width:0.04em;"></span></span><span style="top:-3.677em;"><span class="pstrut" style="height:3em;"></span><span class="mord"><span class="mord">1</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.93em;"><span></span></span></span></span></span><span class="mclose nulldelimiter"></span></span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord"><span class="mopen nulldelimiter"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:1.32144em;"><span style="top:-2.25278em;"><span class="pstrut" style="height:3em;"></span><span class="mord"><span class="mord sqrt"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.85722em;"><span class="svg-align" style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="mord" style="padding-left:0.833em;"><span class="mord"><span class="mord mathnormal">d</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight">in</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span><span 
style="top:-2.81722em;"><span class="pstrut" style="height:3em;"></span><span class="hide-tail" style="min-width:0.853em;height:1.08em;"><svg height="1.08em" preserveaspectratio="xMinYMin slice" viewbox="0 0 400000 1080" width="400em" xmlns="http://www.w3.org/2000/svg"><path d="M95,702
c-2.7,0,-7.17,-2.7,-13.5,-8c-5.8,-5.3,-9.5,-10,-9.5,-14
c0,-2,0.3,-3.3,1,-4c1.3,-2.7,23.83,-20.7,67.5,-54
c44.2,-33.3,65.8,-50.3,66.5,-51c1.3,-1.3,3,-2,5,-2c4.7,0,8.7,3.3,12,10
s173,378,173,378c0.7,0,35.3,-71,104,-213c68.7,-142,137.5,-285,206.5,-429
c69,-144,104.5,-217.7,106.5,-221
l0 -0
c5.3,-9.3,12,-14,20,-14
H400000v40H845.2724
s-225.272,467,-225.272,467s-235,486,-235,486c-2.7,4.7,-9,7,-19,7
c-6,0,-10,-1,-12,-3s-194,-422,-194,-422s-65,47,-65,47z
M834 80h400000v40h-400000z"></path></svg></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.18278000000000005em;"><span></span></span></span></span></span></span></span><span style="top:-3.23em;"><span class="pstrut" style="height:3em;"></span><span class="frac-line" style="border-bottom-width:0.04em;"></span></span><span style="top:-3.677em;"><span class="pstrut" style="height:3em;"></span><span class="mord"><span class="mord">1</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.93em;"><span></span></span></span></span></span><span class="mclose nulldelimiter"></span></span><span class="mord"><span class="delimsizing size4">)</span></span></span></span></span></span></span> </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">349</span>        <span class="n">rnd_range</span> <span class="o">=</span> <span class="mi">1</span> <span class="o">/</span> <span class="n">in_features</span> <span class="o">**</span> <span class="mf">0.5</span>
<span class="lineno">350</span>        <span class="bp">self</span><span class="o">.</span><span class="n">weight</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">uniform</span><span class="p">(</span><span class="n">rnd_key</span><span class="p">,</span> <span class="p">(</span><span class="n">in_features</span><span class="p">,</span> <span class="n">out_features</span><span class="p">),</span>
<span class="lineno">351</span>                                         <span class="n">minval</span><span class="o">=-</span><span class="n">rnd_range</span><span class="p">,</span> <span class="n">maxval</span><span class="o">=</span><span class="n">rnd_range</span><span class="p">)</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-74'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-74'>#</a>
            </div>
            <p>Initialize the biases to <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.64444em;vertical-align:0em;"></span><span class="mord coloredeq eqbz" style=""><span class="mord" style="">0</span></span></span></span></span></span> </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">353</span>        <span class="bp">self</span><span class="o">.</span><span class="n">bias</span> <span class="o">=</span> <span class="n">jnp</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">out_features</span><span class="p">,))</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-75'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-75'>#</a>
            </div>
            
        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">355</span>    <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">jnp</span><span class="o">.</span><span class="n">ndarray</span><span class="p">):</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-76'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-76'>#</a>
            </div>
            <p>Multiply by weights and add the bias </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">357</span>        <span class="k">return</span> <span class="n">jnp</span><span class="o">.</span><span class="n">matmul</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">weight</span><span class="p">)</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">bias</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-77'>
        <div class='docs doc-strings'>
            <div class='section-link'>
                <a href='#section-77'>#</a>
            </div>
            <p> <a id="LayerNormalization"></a></p>
<h2>Layer Normalization</h2>
<p>This implements layer normalization from the paper <a href="https://papers.labml.ai/paper/1607.06450">Layer Normalization</a>.</p>
<p>When input <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.72243em;vertical-align:-0.0391em;"></span><span class="mord mathnormal" style="margin-right:0.07847em;">X</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">∈</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span></span><span class="base"><span class="strut" style="height:0.8413309999999999em;vertical-align:0em;"></span><span class="mord"><span class="mord mathbb">R</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.8413309999999999em;"><span style="top:-3.063em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight coloredeq eqcb" style=""><span class="mord mathnormal mtight" style="">L</span></span><span class="mbin mtight">×</span><span class="mord mtight coloredeq eqca" style=""><span class="mord mathnormal mtight" style="margin-right:0.07153em">C</span></span></span></span></span></span></span></span></span></span></span></span></span></span> is a sequence of embeddings, where <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.68333em;vertical-align:0em;"></span><span class="mord coloredeq eqca" style=""><span class="mord mathnormal" style="margin-right:0.07153em">C</span></span></span></span></span></span> is the number of channels and <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.68333em;vertical-align:0em;"></span><span class="mord coloredeq eqcb" style=""><span class="mord mathnormal" style="">L</span></span></span></span></span></span> is the length of the sequence. The learnable gain and bias are
<span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.7335400000000001em;vertical-align:-0.19444em;"></span><span class="mord coloredeq eqbr" style=""><span class="mord mathnormal" style="margin-right:0.05556em">γ</span></span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">∈</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span></span><span class="base"><span class="strut" style="height:0.8413309999999999em;vertical-align:0em;"></span><span class="mord"><span class="mord mathbb">R</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.8413309999999999em;"><span style="top:-3.063em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight coloredeq eqca" style=""><span class="mord mathnormal mtight" style="margin-right:0.07153em">C</span></span></span></span></span></span></span></span></span></span></span></span></span></span> and <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.8888799999999999em;vertical-align:-0.19444em;"></span><span class="mord coloredeq eqbs" style=""><span class="mord mathnormal" style="margin-right:0.05278em">β</span></span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">∈</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span></span><span class="base"><span class="strut" style="height:0.8413309999999999em;vertical-align:0em;"></span><span class="mord"><span class="mord mathbb">R</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.8413309999999999em;"><span style="top:-3.063em;margin-right:0.05em;"><span class="pstrut" 
style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight coloredeq eqca" style=""><span class="mord mathnormal mtight" style="margin-right:0.07153em">C</span></span></span></span></span></span></span></span></span></span></span></span></span></span>. <span ><span class="katex-display"><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord text"><span class="mord coloredeq eqcb" style=""><span class="mord" style="">L</span></span><span class="mord">N</span></span><span class="mopen">(</span><span class="mord mathnormal" style="margin-right:0.07847em;">X</span><span class="mclose">)</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span></span><span class="base"><span class="strut" style="height:3.614331em;vertical-align:-1.73em;"></span><span class="mord coloredeq eqbr" style=""><span class="mord mathnormal" style="margin-right:0.05556em">γ</span></span><span class="mord"><span class="mopen nulldelimiter"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:1.884331em;"><span style="top:-2.1221655em;"><span class="pstrut" style="height:3em;"></span><span class="mord"><span class="mord sqrt"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.9878345000000001em;"><span class="svg-align" style="top:-3.8em;"><span class="pstrut" style="height:3.8em;"></span><span class="mord" style="padding-left:1em;"><span class="mord"><span class="mop op-limits"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.68333em;"><span style="top:-2.355669em;margin-left:0em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 
size3 mtight"><span class="mord mtight"><span class="mord mtight coloredeq eqca" style=""><span class="mord mathnormal mtight" style="margin-right:0.07153em">C</span></span></span></span></span><span style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span><span class="mop"><span class="mord mathnormal">Va</span><span class="mord mathnormal" style="margin-right:0.02778em;">r</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.744331em;"><span></span></span></span></span></span></span><span class="mopen">[</span><span class="mord mathnormal" style="margin-right:0.07847em;">X</span><span class="mclose">]</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin">+</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mord coloredeq eqbj" style=""><span class="mord mathnormal" style="">ϵ</span></span></span></span><span style="top:-2.9478345em;"><span class="pstrut" style="height:3.8em;"></span><span class="hide-tail" style="min-width:1.02em;height:1.8800000000000001em;"><svg height="1.8800000000000001em" preserveaspectratio="xMinYMin slice" viewbox="0 0 400000 1944" width="400em" xmlns="http://www.w3.org/2000/svg"><path d="M983 90
l0 -0
c4,-6.7,10,-10,18,-10 H400000v40
H1013.1s-83.4,268,-264.1,840c-180.7,572,-277,876.3,-289,913c-4.7,4.7,-12.7,7,-24,7
s-12,0,-12,0c-1.3,-3.3,-3.7,-11.7,-7,-25c-35.3,-125.3,-106.7,-373.3,-214,-744
c-10,12,-21,25,-33,39s-32,39,-32,39c-6,-5.3,-15,-14,-27,-26s25,-30,25,-30
c26.7,-32.7,52,-63,76,-91s52,-60,52,-60s208,722,208,722
c56,-175.3,126.3,-397.3,211,-666c84.7,-268.7,153.8,-488.2,207.5,-658.5
c53.7,-170.3,84.5,-266.8,92.5,-289.5z
M1001 80h400000v40h-400000z"></path></svg></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.8521655em;"><span></span></span></span></span></span></span></span><span style="top:-3.23em;"><span class="pstrut" style="height:3em;"></span><span class="frac-line" style="border-bottom-width:0.04em;"></span></span><span style="top:-4.134331em;"><span class="pstrut" style="height:3em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.07847em;">X</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin">−</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mord"><span class="mop op-limits"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.68889em;"><span style="top:-2.355669em;margin-left:0em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight coloredeq eqca" style=""><span class="mord mathnormal mtight" style="margin-right:0.07153em">C</span></span></span></span></span><span style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span><span class="mop mathbb">E</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.744331em;"><span></span></span></span></span></span></span><span class="mopen">[</span><span class="mord mathnormal" style="margin-right:0.07847em;">X</span><span class="mclose">]</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:1.73em;"><span></span></span></span></span></span><span class="mclose nulldelimiter"></span></span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin">+</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span></span><span 
class="base"><span class="strut" style="height:0.8888799999999999em;vertical-align:-0.19444em;"></span><span class="mord coloredeq eqbs" style=""><span class="mord mathnormal" style="margin-right:0.05278em">β</span></span></span></span></span></span></span></p>
<p>This is based on <a href="https://nn.labml.ai/normalization/layer_norm/index.html">our PyTorch implementation</a>.</p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">360</span><span class="k">class</span> <span class="nc">LayerNorm</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-78'>
        <div class='docs doc-strings'>
            <div class='section-link'>
                <a href='#section-78'>#</a>
            </div>
            <ul><li><code  class="highlight"><span></span><span class="n">normalized_shape</span></code>
 <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.68333em;vertical-align:0em;"></span><span class="mord coloredeq eqcc" style=""><span class="mord mathnormal" style="margin-right:0.05764em">S</span></span></span></span></span></span> is the shape of the elements (except the batch).  The input should then be  <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.72243em;vertical-align:-0.0391em;"></span><span class="mord mathnormal" style="margin-right:0.07847em;">X</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">∈</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span></span><span class="base"><span class="strut" style="height:0.8879999999999999em;vertical-align:0em;"></span><span class="mord"><span class="mord mathbb">R</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.8879999999999999em;"><span style="top:-3.063em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight">∗×</span><span class="mord mtight coloredeq eqcc" style=""><span class="mord mathnormal mtight" style="margin-right:0.05764em">S</span></span><span class="mopen mtight">[</span><span class="mord mtight coloredeq eqbz" style=""><span class="mord mtight" style="">0</span></span><span class="mclose mtight">]</span><span class="mbin mtight">×</span><span class="mord mtight coloredeq eqcc" style=""><span class="mord mathnormal mtight" style="margin-right:0.05764em">S</span></span><span class="mopen mtight">[</span><span class="mord mtight">1</span><span class="mclose mtight">]</span><span class="mbin mtight">×</span><span class="mord mtight">...</span><span class="mbin mtight">×</span><span class="mord mtight coloredeq 
eqcc" style=""><span class="mord mathnormal mtight" style="margin-right:0.05764em">S</span></span><span class="mopen mtight">[</span><span class="mord mathnormal mtight">n</span><span class="mclose mtight">]</span></span></span></span></span></span></span></span></span></span></span></span></span> </li>
<li><code  class="highlight"><span></span><span class="n">eps</span></code>
 is <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.43056em;vertical-align:0em;"></span><span class="mord coloredeq eqbj" style=""><span class="mord mathnormal" style="">ϵ</span></span></span></span></span></span>, used in <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1.24em;vertical-align:-0.30499999999999994em;"></span><span class="mord coloredeq eqv" style=""><span class="mord sqrt" style=""><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.935em;"><span class="svg-align" style="top:-3.2em;"><span class="pstrut" style="height:3.2em;"></span><span class="mord" style="padding-left:1em"><span class="mord mathnormal" style="">Va</span><span class="mord mathnormal" style="margin-right:0.02778em">r</span><span class="mopen" style="">[</span><span class="mord mathnormal" style="margin-right:0.07847em">X</span><span class="mclose" style="">]</span><span class="mspace" style="margin-right:0.2222222222222222em"></span><span class="mbin" style="">+</span><span class="mspace" style="margin-right:0.2222222222222222em"></span><span class="mord" style=""><span class="mord mathnormal coloredeq eqbj" style="">ϵ</span></span></span></span><span style="top:-2.8950000000000005em;"><span class="pstrut" style="height:3.2em;"></span><span class="hide-tail" style="min-width:1.02em;height:1.28em"><svg height="1.28em" preserveaspectratio="xMinYMin slice" viewbox="0 0 400000 1296" width="400em" xmlns="http://www.w3.org/2000/svg"><path d="M263,681c0.7,0,18,39.7,52,119
c34,79.3,68.167,158.7,102.5,238c34.3,79.3,51.8,119.3,52.5,120
c340,-704.7,510.7,-1060.3,512,-1067
l0 -0
c4.7,-7.3,11,-11,19,-11
H40000v40H1012.3
s-271.3,567,-271.3,567c-38.7,80.7,-84,175,-136,283c-52,108,-89.167,185.3,-111.5,232
c-22.3,46.7,-33.8,70.3,-34.5,71c-4.7,4.7,-12.3,7,-23,7s-12,-1,-12,-1
s-109,-253,-109,-253c-72.7,-168,-109.3,-252,-110,-252c-10.7,8,-22,16.7,-34,26
c-22,17.3,-33.3,26,-34,26s-26,-26,-26,-26s76,-59,76,-59s76,-60,76,-60z
M1001 80h400000v40h-400000z"></path></svg></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.30499999999999994em;"><span></span></span></span></span></span></span></span></span></span></span> for numerical stability </li>
<li><code  class="highlight"><span></span><span class="n">elementwise_affine</span></code>
 is whether to scale and shift the normalized value</li></ul>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">380</span>    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">normalized_shape</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">Tuple</span><span class="p">[</span><span class="nb">int</span><span class="p">],</span> <span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">]],</span> <span class="o">*</span><span class="p">,</span>
<span class="lineno">381</span>                 <span class="n">eps</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1e-5</span><span class="p">,</span> <span class="n">elementwise_affine</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">):</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-79'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-79'>#</a>
            </div>
            
        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">389</span>        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
<span class="lineno">390</span>
<span class="lineno">391</span>        <span class="bp">self</span><span class="o">.</span><span class="n">eps</span> <span class="o">=</span> <span class="n">eps</span>
<span class="lineno">392</span>        <span class="bp">self</span><span class="o">.</span><span class="n">elementwise_affine</span> <span class="o">=</span> <span class="n">elementwise_affine</span>
<span class="lineno">393</span>        <span class="bp">self</span><span class="o">.</span><span class="n">normalized_shape</span> <span class="o">=</span> <span class="nb">tuple</span><span class="p">(</span><span class="n">normalized_shape</span><span class="p">)</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-80'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-80'>#</a>
            </div>
            <p>Create parameters for <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.625em;vertical-align:-0.19444em;"></span><span class="mord coloredeq eqbr" style=""><span class="mord mathnormal" style="margin-right:0.05556em">γ</span></span></span></span></span></span> and <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.8888799999999999em;vertical-align:-0.19444em;"></span><span class="mord coloredeq eqbs" style=""><span class="mord mathnormal" style="margin-right:0.05278em">β</span></span></span></span></span></span> for gain and bias </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">396</span>        <span class="k">if</span> <span class="n">elementwise_affine</span><span class="p">:</span>
<span class="lineno">397</span>            <span class="bp">self</span><span class="o">.</span><span class="n">gain</span> <span class="o">=</span> <span class="n">jnp</span><span class="o">.</span><span class="n">ones</span><span class="p">(</span><span class="n">normalized_shape</span><span class="p">)</span>
<span class="lineno">398</span>            <span class="bp">self</span><span class="o">.</span><span class="n">bias</span> <span class="o">=</span> <span class="n">jnp</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">normalized_shape</span><span class="p">)</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-81'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-81'>#</a>
            </div>
            
        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">400</span>    <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">jnp</span><span class="o">.</span><span class="n">ndarray</span><span class="p">):</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-82'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-82'>#</a>
            </div>
            <p>Sanity check to make sure the shapes match </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">402</span>        <span class="k">assert</span> <span class="bp">self</span><span class="o">.</span><span class="n">normalized_shape</span> <span class="o">==</span> <span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="o">-</span><span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">normalized_shape</span><span class="p">):]</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-83'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-83'>#</a>
            </div>
            <p>The axes to calculate the mean and variance over </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">405</span>        <span class="n">axes</span> <span class="o">=</span> <span class="p">[</span><span class="o">-</span><span class="p">(</span><span class="n">i</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">normalized_shape</span><span class="p">))]</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-84'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-84'>#</a>
            </div>
            <p>Calculate the mean of all elements; i.e. the means for each element <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord coloredeq eqbf" style=""><span class="mord mathbb" style="">E</span><span class="mopen" style="">[</span><span class="mord mathnormal" style="margin-right:0.07847em">X</span><span class="mclose" style="">]</span></span></span></span></span></span> </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">408</span>        <span class="n">mean</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">axis</span><span class="o">=</span><span class="n">axes</span><span class="p">,</span> <span class="n">keepdims</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-85'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-85'>#</a>
            </div>
            <p>Calculate the mean of the squared elements; i.e. the means for each element <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1.064108em;vertical-align:-0.25em;"></span><span class="mord coloredeq eqba" style=""><span class="mord mathbb" style="">E</span><span class="mopen" style="">[</span><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.07847em">X</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.8141079999999999em;"><span style="top:-3.063em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style="">2</span></span></span></span></span></span></span></span><span class="mclose" style="">]</span></span></span></span></span></span> </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">411</span>        <span class="n">mean_2</span> <span class="o">=</span> <span class="p">(</span><span class="n">x</span> <span class="o">**</span> <span class="mi">2</span><span class="p">)</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">axis</span><span class="o">=</span><span class="n">axes</span><span class="p">,</span> <span class="n">keepdims</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-86'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-86'>#</a>
            </div>
            <p>Variance of all elements <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord mathnormal">Va</span><span class="mord mathnormal" style="margin-right:0.02778em;">r</span><span class="mopen">[</span><span class="mord mathnormal" style="margin-right:0.07847em;">X</span><span class="mclose">]</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span></span><span class="base"><span class="strut" style="height:1.064108em;vertical-align:-0.25em;"></span><span class="mord coloredeq eqba" style=""><span class="mord mathbb" style="">E</span><span class="mopen" style="">[</span><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.07847em">X</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.8141079999999999em;"><span style="top:-3.063em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style="">2</span></span></span></span></span></span></span></span><span class="mclose" style="">]</span></span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin">−</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span></span><span class="base"><span class="strut" style="height:1.204008em;vertical-align:-0.25em;"></span><span class="mord"><span class="mord coloredeq eqbf" style=""><span class="mord mathbb" style="">E</span><span class="mopen" style="">[</span><span class="mord mathnormal" style="margin-right:0.07847em">X</span><span class="mclose" style="">]</span></span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.954008em;"><span 
style="top:-3.2029em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">2</span></span></span></span></span></span></span></span></span></span></span></span> </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">413</span>        <span class="n">var</span> <span class="o">=</span> <span class="n">mean_2</span> <span class="o">-</span> <span class="n">mean</span> <span class="o">**</span> <span class="mi">2</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-87'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-87'>#</a>
            </div>
            <p>Normalize <span ><span class="katex-display"><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.9467699999999999em;vertical-align:0em;"></span><span class="mord accent"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.9467699999999999em;"><span style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="mord mathnormal" style="margin-right:0.07847em;">X</span></span><span style="top:-3.25233em;"><span class="pstrut" style="height:3em;"></span><span class="accent-body" style="left:-0.16666em;"><span class="mord">^</span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span></span><span class="base"><span class="strut" style="height:2.557em;vertical-align:-1.13em;"></span><span class="mord"><span class="mopen nulldelimiter"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:1.427em;"><span style="top:-2.175em;"><span class="pstrut" style="height:3em;"></span><span class="mord"><span class="mord coloredeq eqv" style=""><span class="mord sqrt" style=""><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.935em;"><span class="svg-align" style="top:-3.2em;"><span class="pstrut" style="height:3.2em;"></span><span class="mord" style="padding-left:1em"><span class="mord mathnormal" style="">Va</span><span class="mord mathnormal" style="margin-right:0.02778em">r</span><span class="mopen" style="">[</span><span class="mord mathnormal" style="margin-right:0.07847em">X</span><span class="mclose" style="">]</span><span class="mspace" style="margin-right:0.2222222222222222em"></span><span class="mbin" style="">+</span><span class="mspace" style="margin-right:0.2222222222222222em"></span><span 
class="mord" style=""><span class="mord mathnormal coloredeq eqbj" style="">ϵ</span></span></span></span><span style="top:-2.8950000000000005em;"><span class="pstrut" style="height:3.2em;"></span><span class="hide-tail" style="min-width:1.02em;height:1.28em"><svg height="1.28em" preserveaspectratio="xMinYMin slice" viewbox="0 0 400000 1296" width="400em" xmlns="http://www.w3.org/2000/svg"><path d="M263,681c0.7,0,18,39.7,52,119
c34,79.3,68.167,158.7,102.5,238c34.3,79.3,51.8,119.3,52.5,120
c340,-704.7,510.7,-1060.3,512,-1067
l0 -0
c4.7,-7.3,11,-11,19,-11
H40000v40H1012.3
s-271.3,567,-271.3,567c-38.7,80.7,-84,175,-136,283c-52,108,-89.167,185.3,-111.5,232
c-22.3,46.7,-33.8,70.3,-34.5,71c-4.7,4.7,-12.3,7,-23,7s-12,-1,-12,-1
s-109,-253,-109,-253c-72.7,-168,-109.3,-252,-110,-252c-10.7,8,-22,16.7,-34,26
c-22,17.3,-33.3,26,-34,26s-26,-26,-26,-26s76,-59,76,-59s76,-60,76,-60z
M1001 80h400000v40h-400000z"></path></svg></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.30499999999999994em;"><span></span></span></span></span></span></span></span></span><span style="top:-3.23em;"><span class="pstrut" style="height:3em;"></span><span class="frac-line" style="border-bottom-width:0.04em;"></span></span><span style="top:-3.677em;"><span class="pstrut" style="height:3em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.07847em;">X</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin">−</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mord coloredeq eqbf" style=""><span class="mord mathbb" style="">E</span><span class="mopen" style="">[</span><span class="mord mathnormal" style="margin-right:0.07847em">X</span><span class="mclose" style="">]</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:1.13em;"><span></span></span></span></span></span><span class="mclose nulldelimiter"></span></span></span></span></span></span></span> </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">415</span>        <span class="n">x_norm</span> <span class="o">=</span> <span class="p">(</span><span class="n">x</span> <span class="o">-</span> <span class="n">mean</span><span class="p">)</span> <span class="o">/</span> <span class="p">(</span><span class="n">var</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">eps</span><span class="p">)</span> <span class="o">**</span> <span class="mf">0.5</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-88'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-88'>#</a>
            </div>
            <p>Scale and shift <span ><span class="katex-display"><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord text"><span class="mord coloredeq eqcb" style=""><span class="mord" style="">L</span></span><span class="mord">N</span></span><span class="mopen">(</span><span class="mord mathnormal">x</span><span class="mclose">)</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span></span><span class="base"><span class="strut" style="height:1.1412099999999998em;vertical-align:-0.19444em;"></span><span class="mord coloredeq eqbr" style=""><span class="mord mathnormal" style="margin-right:0.05556em">γ</span></span><span class="mord accent"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.9467699999999999em;"><span style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="mord mathnormal" style="margin-right:0.07847em;">X</span></span><span style="top:-3.25233em;"><span class="pstrut" style="height:3em;"></span><span class="accent-body" style="left:-0.16666em;"><span class="mord">^</span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin">+</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span></span><span class="base"><span class="strut" style="height:0.8888799999999999em;vertical-align:-0.19444em;"></span><span class="mord coloredeq eqbs" style=""><span class="mord mathnormal" style="margin-right:0.05278em">β</span></span></span></span></span></span></span> </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">418</span>        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">elementwise_affine</span><span class="p">:</span>
<span class="lineno">419</span>            <span class="n">x_norm</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">gain</span> <span class="o">*</span> <span class="n">x_norm</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">bias</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-89'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-89'>#</a>
            </div>
            <p> </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">422</span>        <span class="k">return</span> <span class="n">x_norm</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-90'>
        <div class='docs doc-strings'>
            <div class='section-link'>
                <a href='#section-90'>#</a>
            </div>
            <p> <a id="MHA"></a></p>
<h2>Multi-Head Attention Module</h2>
<p>This computes scaled multi-headed attention from the paper <a href="https://papers.labml.ai/paper/1706.03762">Attention Is All You Need</a> for given <code  class="highlight"><span></span><span class="n">query</span></code>
, <code  class="highlight"><span></span><span class="n">key</span></code>
 and <code  class="highlight"><span></span><span class="n">value</span></code>
 vectors.</p>
<p><span ><span class="katex-display"><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mop"><span class="mord mathnormal">A</span><span class="mord coloredeq eqce" style=""><span class="mord mathnormal" style="">t</span></span><span class="mord coloredeq eqce" style=""><span class="mord mathnormal" style="">t</span></span><span class="mord mathnormal">e</span><span class="mord mathnormal">n</span><span class="mord coloredeq eqce" style=""><span class="mord mathnormal" style="">t</span></span><span class="mord mathnormal">i</span><span class="mord mathnormal">o</span><span class="mord mathnormal">n</span></span><span class="mopen">(</span><span class="mord mathnormal">Q</span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord mathnormal" style="margin-right:0.07153em;">K</span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord mathnormal" style="margin-right:0.22222em;">V</span><span class="mclose">)</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span></span><span class="base"><span class="strut" style="height:3.0000299999999998em;vertical-align:-1.25003em;"></span><span class="mord"><span class="mop op-limits"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.6944399999999998em;"><span style="top:-2.20556em;margin-left:0em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight">se</span><span class="mord mathnormal mtight" style="margin-right:0.03588em;">q</span></span></span></span><span style="top:-3em;"><span class="pstrut" 
style="height:3em;"></span><span><span class="mop"><span class="mop"><span class="mord coloredeq eqbo" style=""><span class="mord mathnormal" style="">so</span><span class="mord" style=""><span class="mord mathnormal coloredeq eqcd" style="margin-right:0.10764em">f</span></span><span class="mord" style=""><span class="mord mathnormal coloredeq eqce" style="">t</span></span><span class="mord mathnormal" style="">ma</span><span class="mord mathnormal" style="">x</span></span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:1.030548em;"><span></span></span></span></span></span></span><span class="mord"><span class="delimsizing size4">(</span></span><span class="mord coloredeq eqr" style=""><span class="mord" style=""><span class="mopen nulldelimiter"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:1.5261079999999998em;"><span style="top:-2.25278em;"><span class="pstrut" style="height:3em;"></span><span class="mord" style=""><span class="mord sqrt" style=""><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.85722em;"><span class="svg-align" style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="mord" style="padding-left:0.833em"><span class="mord" style=""><span class="mord coloredeq eqbv" style=""><span class="mord mathnormal" style="">d</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.33610799999999996em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mathnormal mtight" style="margin-right:0.03148em">k</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" 
style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span><span style="top:-2.81722em;"><span class="pstrut" style="height:3em;"></span><span class="hide-tail" style="min-width:0.853em;height:1.08em"><svg height="1.08em" preserveaspectratio="xMinYMin slice" viewbox="0 0 400000 1080" width="400em" xmlns="http://www.w3.org/2000/svg"><path d="M95,702
c-2.7,0,-7.17,-2.7,-13.5,-8c-5.8,-5.3,-9.5,-10,-9.5,-14
c0,-2,0.3,-3.3,1,-4c1.3,-2.7,23.83,-20.7,67.5,-54
c44.2,-33.3,65.8,-50.3,66.5,-51c1.3,-1.3,3,-2,5,-2c4.7,0,8.7,3.3,12,10
s173,378,173,378c0.7,0,35.3,-71,104,-213c68.7,-142,137.5,-285,206.5,-429
c69,-144,104.5,-217.7,106.5,-221
l0 -0
c5.3,-9.3,12,-14,20,-14
H400000v40H845.2724
s-225.272,467,-225.272,467s-235,486,-235,486c-2.7,4.7,-9,7,-19,7
c-6,0,-10,-1,-12,-3s-194,-422,-194,-422s-65,47,-65,47z
M834 80h400000v40h-400000z"></path></svg></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.18278000000000005em;"><span></span></span></span></span></span></span></span><span style="top:-3.23em;"><span class="pstrut" style="height:3em;"></span><span class="frac-line" style="border-bottom-width:0.04em"></span></span><span style="top:-3.677em;"><span class="pstrut" style="height:3em;"></span><span class="mord" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqbi" style="">Q</span><span class="mord coloredeq eqbi" style=""><span class="mord mathnormal" style="margin-right:0.07153em">K</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.849108em;"><span style="top:-3.063em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style="">⊤</span></span></span></span></span></span></span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.93em;"><span></span></span></span></span></span><span class="mclose nulldelimiter"></span></span></span><span class="mord"><span class="delimsizing size4">)</span></span><span class="mord mathnormal" style="margin-right:0.22222em;">V</span></span></span></span></span></span></p>
<p>In simple terms, it finds keys that match the query, and gets the values of those keys.</p>
<p>It uses dot-product of query and key as the indicator of how matching they are. Before taking the <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.8888799999999999em;vertical-align:-0.19444em;"></span><span class="mord coloredeq eqbo" style=""><span class="mord mathnormal" style="">so</span><span class="mord" style=""><span class="mord mathnormal coloredeq eqcd" style="margin-right:0.10764em">f</span></span><span class="mord" style=""><span class="mord mathnormal coloredeq eqce" style="">t</span></span><span class="mord mathnormal" style="">ma</span><span class="mord mathnormal" style="">x</span></span></span></span></span></span> the dot-products are scaled by <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1.383108em;vertical-align:-0.538em;"></span><span class="mord"><span class="mopen nulldelimiter"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.845108em;"><span style="top:-2.5864385em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord sqrt mtight"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.8622307142857143em;"><span class="svg-align" style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="mord mtight" style="padding-left:0.833em;"><span class="mord mtight coloredeq eqbv" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight" style="">d</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3448em;"><span style="top:-2.3487714285714287em;margin-left:0em;margin-right:0.07142857142857144em;"><span class="pstrut" style="height:2.5em;"></span><span class="sizing reset-size3 size1 mtight" style=""><span 
class="mord mathnormal mtight" style="margin-right:0.03148em">k</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15122857142857138em;"><span></span></span></span></span></span></span></span></span></span><span style="top:-2.8222307142857144em;"><span class="pstrut" style="height:3em;"></span><span class="hide-tail mtight" style="min-width:0.853em;height:1.08em;"><svg height="1.08em" preserveaspectratio="xMinYMin slice" viewbox="0 0 400000 1080" width="400em" xmlns="http://www.w3.org/2000/svg"><path d="M95,702
c-2.7,0,-7.17,-2.7,-13.5,-8c-5.8,-5.3,-9.5,-10,-9.5,-14
c0,-2,0.3,-3.3,1,-4c1.3,-2.7,23.83,-20.7,67.5,-54
c44.2,-33.3,65.8,-50.3,66.5,-51c1.3,-1.3,3,-2,5,-2c4.7,0,8.7,3.3,12,10
s173,378,173,378c0.7,0,35.3,-71,104,-213c68.7,-142,137.5,-285,206.5,-429
c69,-144,104.5,-217.7,106.5,-221
l0 -0
c5.3,-9.3,12,-14,20,-14
H400000v40H845.2724
s-225.272,467,-225.272,467s-235,486,-235,486c-2.7,4.7,-9,7,-19,7
c-6,0,-10,-1,-12,-3s-194,-422,-194,-422s-65,47,-65,47z
M834 80h400000v40h-400000z"></path></svg></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.17776928571428574em;"><span></span></span></span></span></span></span></span></span><span style="top:-3.23em;"><span class="pstrut" style="height:3em;"></span><span class="frac-line" style="border-bottom-width:0.04em;"></span></span><span style="top:-3.394em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight">1</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.538em;"><span></span></span></span></span></span><span class="mclose nulldelimiter"></span></span></span></span></span></span>. This is done to avoid large dot-product values causing softmax to give very small gradients when <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.84444em;vertical-align:-0.15em;"></span><span class="mord coloredeq eqbv" style=""><span class="mord" style=""><span class="mord mathnormal" style="">d</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.33610799999999996em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mathnormal mtight" style="margin-right:0.03148em">k</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span></span> is large.</p>
<p>Softmax is calculated along the axis of the sequence (or time) for keys.</p>
<p>This is based on <a href="https://nn.labml.ai/transformers/mha.html#MHA">our PyTorch implementation</a>.</p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">425</span><span class="k">class</span> <span class="nc">MultiHeadAttention</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-91'>
        <div class='docs doc-strings'>
            <div class='section-link'>
                <a href='#section-91'>#</a>
            </div>
            <ul><li><code  class="highlight"><span></span><span class="n">rnd_key</span></code>
 is the PRNG state </li>
<li><code  class="highlight"><span></span><span class="n">heads</span></code>
 is the number of heads. </li>
<li><code  class="highlight"><span></span><span class="n">d_model</span></code>
 is the number of features in the <code  class="highlight"><span></span><span class="n">query</span></code>
, <code  class="highlight"><span></span><span class="n">key</span></code>
 and <code  class="highlight"><span></span><span class="n">value</span></code>
 vectors.</li></ul>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">451</span>    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">rnd_key</span><span class="p">:</span> <span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">PRNGKey</span><span class="p">,</span> <span class="n">heads</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">d_model</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-92'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-92'>#</a>
            </div>
            
        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">458</span>        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-93'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-93'>#</a>
            </div>
            <p>Split the PRNG state </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">461</span>        <span class="n">_</span><span class="p">,</span> <span class="o">*</span><span class="n">rnd_keys</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="n">rnd_key</span><span class="p">,</span> <span class="mi">5</span><span class="p">)</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-94'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-94'>#</a>
            </div>
            <p>Number of features per head </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">464</span>        <span class="bp">self</span><span class="o">.</span><span class="n">d_k</span> <span class="o">=</span> <span class="n">d_model</span> <span class="o">//</span> <span class="n">heads</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-95'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-95'>#</a>
            </div>
            <p>Number of heads </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">466</span>        <span class="bp">self</span><span class="o">.</span><span class="n">heads</span> <span class="o">=</span> <span class="n">heads</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-96'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-96'>#</a>
            </div>
            <p>These transform the <code  class="highlight"><span></span><span class="n">query</span></code>
, <code  class="highlight"><span></span><span class="n">key</span></code>
 and <code  class="highlight"><span></span><span class="n">value</span></code>
 vectors for multi-headed attention. </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">469</span>        <span class="bp">self</span><span class="o">.</span><span class="n">query</span> <span class="o">=</span> <span class="n">Linear</span><span class="p">(</span><span class="n">rnd_keys</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">d_model</span><span class="p">,</span> <span class="n">d_model</span><span class="p">)</span>
<span class="lineno">470</span>        <span class="bp">self</span><span class="o">.</span><span class="n">key</span> <span class="o">=</span> <span class="n">Linear</span><span class="p">(</span><span class="n">rnd_keys</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="n">d_model</span><span class="p">,</span> <span class="n">d_model</span><span class="p">)</span>
<span class="lineno">471</span>        <span class="bp">self</span><span class="o">.</span><span class="n">value</span> <span class="o">=</span> <span class="n">Linear</span><span class="p">(</span><span class="n">rnd_keys</span><span class="p">[</span><span class="mi">2</span><span class="p">],</span> <span class="n">d_model</span><span class="p">,</span> <span class="n">d_model</span><span class="p">)</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-97'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-97'>#</a>
            </div>
            <p>Output layer </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">474</span>        <span class="bp">self</span><span class="o">.</span><span class="n">output</span> <span class="o">=</span> <span class="n">Linear</span><span class="p">(</span><span class="n">rnd_keys</span><span class="p">[</span><span class="mi">3</span><span class="p">],</span> <span class="n">d_model</span><span class="p">,</span> <span class="n">d_model</span><span class="p">)</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-98'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-98'>#</a>
            </div>
            <p>Scaling factor before the softmax </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">476</span>        <span class="bp">self</span><span class="o">.</span><span class="n">scale</span> <span class="o">=</span> <span class="mi">1</span> <span class="o">/</span> <span class="bp">self</span><span class="o">.</span><span class="n">d_k</span> <span class="o">**</span> <span class="mf">0.5</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-99'>
        <div class='docs doc-strings'>
            <div class='section-link'>
                <a href='#section-99'>#</a>
            </div>
            <p> <code  class="highlight"><span></span><span class="n">query</span></code>
, <code  class="highlight"><span></span><span class="n">key</span></code>
 and <code  class="highlight"><span></span><span class="n">value</span></code>
 are the tensors that store a collection of <em>query</em>, <em>key</em> and <em>value</em> vectors. They have shape <code  class="highlight"><span></span><span class="p">[</span><span class="n">seq_len</span><span class="p">,</span> <span class="n">d_model</span><span class="p">]</span></code>
.</p>
<p><code  class="highlight"><span></span><span class="n">mask</span></code>
 has shape <code  class="highlight"><span></span><span class="p">[</span><span class="n">seq_len</span><span class="p">,</span> <span class="n">seq_len</span><span class="p">]</span></code>
 and <code  class="highlight"><span></span><span class="n">mask</span><span class="p">[</span><span class="n">i</span><span class="p">,</span> <span class="n">j</span><span class="p">]</span></code>
 indicates whether the query at position <code  class="highlight"><span></span><span class="n">i</span></code>
 can see the key-value pair at position <code  class="highlight"><span></span><span class="n">j</span></code>
.</p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">478</span>    <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">*</span><span class="p">,</span>
<span class="lineno">479</span>                 <span class="n">query</span><span class="p">:</span> <span class="n">jnp</span><span class="o">.</span><span class="n">ndarray</span><span class="p">,</span>
<span class="lineno">480</span>                 <span class="n">key</span><span class="p">:</span> <span class="n">jnp</span><span class="o">.</span><span class="n">ndarray</span><span class="p">,</span>
<span class="lineno">481</span>                 <span class="n">value</span><span class="p">:</span> <span class="n">jnp</span><span class="o">.</span><span class="n">ndarray</span><span class="p">,</span>
<span class="lineno">482</span>                 <span class="n">mask</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">jnp</span><span class="o">.</span><span class="n">ndarray</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-100'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-100'>#</a>
            </div>
            <p>Get sequence length </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">493</span>        <span class="n">seq_len</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">query</span><span class="p">)</span>
<span class="lineno">494</span>
<span class="lineno">495</span>        <span class="k">if</span> <span class="n">mask</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-101'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-101'>#</a>
            </div>
            <p>Check mask shape </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">497</span>            <span class="k">assert</span> <span class="n">mask</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">==</span> <span class="n">query</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
<span class="lineno">498</span>            <span class="k">assert</span> <span class="n">mask</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">==</span> <span class="n">key</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-102'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-102'>#</a>
            </div>
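            <p>The same mask is applied to all heads: add a trailing axis so that the mask has shape <code  class="highlight"><span></span><span class="p">[</span><span class="n">seq_len</span><span class="p">,</span> <span class="n">seq_len</span><span class="p">,</span> <span class="mi">1</span><span class="p">]</span></code> and broadcasts over the <code  class="highlight"><span></span><span class="n">heads</span></code> dimension of the attention scores. </p>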

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">501</span>            <span class="n">mask</span> <span class="o">=</span> <span class="n">mask</span><span class="p">[:,</span> <span class="p">:,</span> <span class="kc">None</span><span class="p">]</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-103'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-103'>#</a>
            </div>
            <p>Apply linear transformations </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">504</span>        <span class="n">query</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">query</span><span class="p">(</span><span class="n">query</span><span class="p">)</span>
<span class="lineno">505</span>        <span class="n">key</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">key</span><span class="p">(</span><span class="n">key</span><span class="p">)</span>
<span class="lineno">506</span>        <span class="n">value</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">value</span><span class="p">(</span><span class="n">value</span><span class="p">)</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-104'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-104'>#</a>
            </div>
            <p>Reshape to split into heads. Input has shape <code  class="highlight"><span></span><span class="p">[</span><span class="n">seq_len</span><span class="p">,</span> <span class="n">d_model</span><span class="p">]</span></code>
. We split the last dimension into <code  class="highlight"><span></span><span class="n">heads</span></code>
 and <code  class="highlight"><span></span><span class="n">d_k</span></code>
, giving shape <code  class="highlight"><span></span><span class="p">[</span><span class="n">seq_len</span><span class="p">,</span> <span class="n">heads</span><span class="p">,</span> <span class="n">d_k</span><span class="p">]</span></code>
. </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">511</span>        <span class="n">query</span> <span class="o">=</span> <span class="n">query</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="o">*</span><span class="n">query</span><span class="o">.</span><span class="n">shape</span><span class="p">[:</span><span class="o">-</span><span class="mi">1</span><span class="p">],</span> <span class="bp">self</span><span class="o">.</span><span class="n">heads</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">d_k</span><span class="p">)</span>
<span class="lineno">512</span>        <span class="n">key</span> <span class="o">=</span> <span class="n">key</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="o">*</span><span class="n">key</span><span class="o">.</span><span class="n">shape</span><span class="p">[:</span><span class="o">-</span><span class="mi">1</span><span class="p">],</span> <span class="bp">self</span><span class="o">.</span><span class="n">heads</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">d_k</span><span class="p">)</span>
<span class="lineno">513</span>        <span class="n">value</span> <span class="o">=</span> <span class="n">value</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="o">*</span><span class="n">value</span><span class="o">.</span><span class="n">shape</span><span class="p">[:</span><span class="o">-</span><span class="mi">1</span><span class="p">],</span> <span class="bp">self</span><span class="o">.</span><span class="n">heads</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">d_k</span><span class="p">)</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-105'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-105'>#</a>
            </div>
            <p>Compute attention scores <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1.043548em;vertical-align:-0.19444em;"></span><span class="mord coloredeq eqbi" style=""><span class="mord mathnormal" style="">Q</span><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.07153em">K</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.849108em;"><span style="top:-3.063em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style="">⊤</span></span></span></span></span></span></span></span></span></span></span></span></span>. This gives a tensor of shape <code  class="highlight"><span></span><span class="p">[</span><span class="n">seq_len</span><span class="p">,</span> <span class="n">seq_len</span><span class="p">,</span> <span class="n">heads</span><span class="p">]</span></code>
. <span ><span class="katex-display"><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.969438em;vertical-align:-0.286108em;"></span><span class="mord"><span class="mord coloredeq eqcc" style=""><span class="mord mathnormal" style="margin-right:0.05764em">S</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3361079999999999em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight">ijh</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span></span><span class="base"><span class="strut" style="height:2.3521180000000004em;vertical-align:-1.3021129999999999em;"></span><span class="mop op-limits"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:1.0500050000000005em;"><span style="top:-1.847887em;margin-left:0em;"><span class="pstrut" style="height:3.05em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">d</span></span></span><span style="top:-3.050005em;"><span class="pstrut" style="height:3.05em;"></span><span><span class="mop op-symbol large-op">∑</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:1.3021129999999999em;"><span></span></span></span></span></span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord"><span class="mord mathnormal">Q</span><span class="msupsub"><span 
class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.33610799999999996em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight">ih</span><span class="mord mathnormal mtight">d</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.07153em;">K</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3361079999999999em;"><span style="top:-2.5500000000000003em;margin-left:-0.07153em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight">jh</span><span class="mord mathnormal mtight">d</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span></span></span></span></span></span> </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">518</span>        <span class="n">scores</span> <span class="o">=</span> <span class="n">jnp</span><span class="o">.</span><span class="n">einsum</span><span class="p">(</span><span class="s1">&#39;ihd,jhd-&gt;ijh&#39;</span><span class="p">,</span> <span class="n">query</span><span class="p">,</span> <span class="n">key</span><span class="p">)</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-106'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-106'>#</a>
            </div>
            <p>Scale scores <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1.633028em;vertical-align:-0.538em;"></span><span class="mord coloredeq eqr" style=""><span class="mord" style=""><span class="mopen nulldelimiter"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:1.095028em;"><span style="top:-2.5864385em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord sqrt mtight" style=""><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.8622307142857143em;"><span class="svg-align" style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="mord mtight" style="padding-left:0.833em"><span class="mord mtight" style=""><span class="mord mtight coloredeq eqbv" style=""><span class="mord mathnormal mtight" style="">d</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3448em;"><span style="top:-2.3487714285714287em;margin-left:0em;margin-right:0.07142857142857144em;"><span class="pstrut" style="height:2.5em;"></span><span class="sizing reset-size3 size1 mtight" style=""><span class="mord mathnormal mtight" style="margin-right:0.03148em">k</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15122857142857138em;"><span></span></span></span></span></span></span></span></span></span><span style="top:-2.8222307142857144em;"><span class="pstrut" style="height:3em;"></span><span class="hide-tail mtight" style="min-width:0.853em;height:1.08em"><svg height="1.08em" preserveaspectratio="xMinYMin slice" viewbox="0 0 400000 1080" width="400em" xmlns="http://www.w3.org/2000/svg"><path d="M95,702
c-2.7,0,-7.17,-2.7,-13.5,-8c-5.8,-5.3,-9.5,-10,-9.5,-14
c0,-2,0.3,-3.3,1,-4c1.3,-2.7,23.83,-20.7,67.5,-54
c44.2,-33.3,65.8,-50.3,66.5,-51c1.3,-1.3,3,-2,5,-2c4.7,0,8.7,3.3,12,10
s173,378,173,378c0.7,0,35.3,-71,104,-213c68.7,-142,137.5,-285,206.5,-429
c69,-144,104.5,-217.7,106.5,-221
l0 -0
c5.3,-9.3,12,-14,20,-14
H400000v40H845.2724
s-225.272,467,-225.272,467s-235,486,-235,486c-2.7,4.7,-9,7,-19,7
c-6,0,-10,-1,-12,-3s-194,-422,-194,-422s-65,47,-65,47z
M834 80h400000v40h-400000z"></path></svg></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.17776928571428574em;"><span></span></span></span></span></span></span></span></span><span style="top:-3.23em;"><span class="pstrut" style="height:3em;"></span><span class="frac-line" style="border-bottom-width:0.04em"></span></span><span style="top:-3.446108em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqbi" style="">Q</span><span class="mord mtight coloredeq eqbi" style=""><span class="mord mathnormal mtight" style="margin-right:0.07153em">K</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.9270285714285713em;"><span style="top:-2.931em;margin-right:0.07142857142857144em;"><span class="pstrut" style="height:2.5em;"></span><span class="sizing reset-size3 size1 mtight" style=""><span class="mord mtight" style="">⊤</span></span></span></span></span></span></span></span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.538em;"><span></span></span></span></span></span><span class="mclose nulldelimiter"></span></span></span></span></span></span></span> </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">521</span>        <span class="n">scores</span> <span class="o">*=</span> <span class="bp">self</span><span class="o">.</span><span class="n">scale</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-107'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-107'>#</a>
            </div>
            <p>Apply mask </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">524</span>        <span class="k">if</span> <span class="n">mask</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="lineno">525</span>            <span class="n">scores</span> <span class="o">=</span> <span class="n">scores</span> <span class="o">+</span> <span class="p">(</span><span class="n">mask</span> <span class="o">==</span> <span class="mi">0</span><span class="p">)</span> <span class="o">*</span> <span class="nb">float</span><span class="p">(</span><span class="s1">&#39;-inf&#39;</span><span class="p">)</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-108'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-108'>#</a>
            </div>
            <p><span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.8888799999999999em;vertical-align:-0.19444em;"></span><span class="mord coloredeq eqbo" style=""><span class="mord mathnormal" style="">so</span><span class="mord" style=""><span class="mord mathnormal coloredeq eqcd" style="margin-right:0.10764em">f</span></span><span class="mord" style=""><span class="mord mathnormal coloredeq eqce" style="">t</span></span><span class="mord mathnormal" style="">ma</span><span class="mord mathnormal" style="">x</span></span></span></span></span></span> attention along the key sequence dimension <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:3.0000299999999998em;vertical-align:-1.25003em;"></span><span class="mord coloredeq eqh" style=""><span class="mord" style=""><span class="mop op-limits" style=""><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.6944399999999998em;"><span style="top:-2.20556em;margin-left:0em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight" style="">se</span><span class="mord mathnormal mtight" style="margin-right:0.03588em">q</span></span></span></span><span style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span><span class="mop" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqbo" style="">so</span><span class="mord coloredeq eqbo" style=""><span class="mord mathnormal coloredeq eqcd" style="margin-right:0.10764em">f</span></span><span class="mord coloredeq eqbo" style=""><span class="mord mathnormal coloredeq eqce" style="">t</span></span><span class="mord mathnormal coloredeq eqbo" style="">ma</span><span class="mord mathnormal coloredeq eqbo" 
style="">x</span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:1.030548em;"><span></span></span></span></span></span></span><span class="mord" style=""><span class="delimsizing size4" style=""><span style="">(</span></span></span><span class="mord" style=""><span class="mord coloredeq eqr" style=""><span class="mopen nulldelimiter"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:1.095028em;"><span style="top:-2.5864385em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord sqrt mtight" style=""><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.8622307142857143em;"><span class="svg-align" style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="mord mtight" style="padding-left:0.833em"><span class="mord mtight" style=""><span class="mord mtight coloredeq eqbv" style=""><span class="mord mathnormal mtight" style="">d</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3448em;"><span style="top:-2.3487714285714287em;margin-left:0em;margin-right:0.07142857142857144em;"><span class="pstrut" style="height:2.5em;"></span><span class="sizing reset-size3 size1 mtight" style=""><span class="mord mathnormal mtight" style="margin-right:0.03148em">k</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15122857142857138em;"><span></span></span></span></span></span></span></span></span></span><span style="top:-2.8222307142857144em;"><span class="pstrut" style="height:3em;"></span><span class="hide-tail mtight" style="min-width:0.853em;height:1.08em"><svg height="1.08em" preserveaspectratio="xMinYMin slice" viewbox="0 0 400000 1080" 
width="400em" xmlns="http://www.w3.org/2000/svg"><path d="M95,702
c-2.7,0,-7.17,-2.7,-13.5,-8c-5.8,-5.3,-9.5,-10,-9.5,-14
c0,-2,0.3,-3.3,1,-4c1.3,-2.7,23.83,-20.7,67.5,-54
c44.2,-33.3,65.8,-50.3,66.5,-51c1.3,-1.3,3,-2,5,-2c4.7,0,8.7,3.3,12,10
s173,378,173,378c0.7,0,35.3,-71,104,-213c68.7,-142,137.5,-285,206.5,-429
c69,-144,104.5,-217.7,106.5,-221
l0 -0
c5.3,-9.3,12,-14,20,-14
H400000v40H845.2724
s-225.272,467,-225.272,467s-235,486,-235,486c-2.7,4.7,-9,7,-19,7
c-6,0,-10,-1,-12,-3s-194,-422,-194,-422s-65,47,-65,47z
M834 80h400000v40h-400000z"></path></svg></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.17776928571428574em;"><span></span></span></span></span></span></span></span></span><span style="top:-3.23em;"><span class="pstrut" style="height:3em;"></span><span class="frac-line" style="border-bottom-width:0.04em"></span></span><span style="top:-3.446108em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqbi" style="">Q</span><span class="mord mtight coloredeq eqbi" style=""><span class="mord mathnormal mtight" style="margin-right:0.07153em">K</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.9270285714285713em;"><span style="top:-2.931em;margin-right:0.07142857142857144em;"><span class="pstrut" style="height:2.5em;"></span><span class="sizing reset-size3 size1 mtight" style=""><span class="mord mtight" style="">⊤</span></span></span></span></span></span></span></span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.538em;"><span></span></span></span></span></span><span class="mclose nulldelimiter"></span></span></span><span class="mord" style=""><span class="delimsizing size4" style=""><span style="">)</span></span></span></span></span></span></span></span> </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">529</span>        <span class="n">attn</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">softmax</span><span class="p">(</span><span class="n">scores</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-109'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-109'>#</a>
            </div>
            <p>Multiply by values <span ><span class="katex-display"><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:3.0000299999999998em;vertical-align:-1.25003em;"></span><span class="mord coloredeq eqh" style=""><span class="mord" style=""><span class="mop op-limits" style=""><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.6944399999999998em;"><span style="top:-2.20556em;margin-left:0em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight" style="">se</span><span class="mord mathnormal mtight" style="margin-right:0.03588em">q</span></span></span></span><span style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span><span class="mop" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqbo" style="">so</span><span class="mord coloredeq eqbo" style=""><span class="mord mathnormal coloredeq eqcd" style="margin-right:0.10764em">f</span></span><span class="mord coloredeq eqbo" style=""><span class="mord mathnormal coloredeq eqce" style="">t</span></span><span class="mord mathnormal coloredeq eqbo" style="">ma</span><span class="mord mathnormal coloredeq eqbo" style="">x</span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:1.030548em;"><span></span></span></span></span></span></span><span class="mord" style=""><span class="delimsizing size4" style=""><span style="">(</span></span></span><span class="mord" style=""><span class="mord coloredeq eqr" style=""><span class="mopen nulldelimiter"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:1.5261079999999998em;"><span style="top:-2.25278em;"><span class="pstrut" style="height:3em;"></span><span class="mord" 
style=""><span class="mord sqrt" style=""><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.85722em;"><span class="svg-align" style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="mord" style="padding-left:0.833em"><span class="mord" style=""><span class="mord coloredeq eqbv" style=""><span class="mord mathnormal" style="">d</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.33610799999999996em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mathnormal mtight" style="margin-right:0.03148em">k</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span><span style="top:-2.81722em;"><span class="pstrut" style="height:3em;"></span><span class="hide-tail" style="min-width:0.853em;height:1.08em"><svg height="1.08em" preserveaspectratio="xMinYMin slice" viewbox="0 0 400000 1080" width="400em" xmlns="http://www.w3.org/2000/svg"><path d="M95,702
c-2.7,0,-7.17,-2.7,-13.5,-8c-5.8,-5.3,-9.5,-10,-9.5,-14
c0,-2,0.3,-3.3,1,-4c1.3,-2.7,23.83,-20.7,67.5,-54
c44.2,-33.3,65.8,-50.3,66.5,-51c1.3,-1.3,3,-2,5,-2c4.7,0,8.7,3.3,12,10
s173,378,173,378c0.7,0,35.3,-71,104,-213c68.7,-142,137.5,-285,206.5,-429
c69,-144,104.5,-217.7,106.5,-221
l0 -0
c5.3,-9.3,12,-14,20,-14
H400000v40H845.2724
s-225.272,467,-225.272,467s-235,486,-235,486c-2.7,4.7,-9,7,-19,7
c-6,0,-10,-1,-12,-3s-194,-422,-194,-422s-65,47,-65,47z
M834 80h400000v40h-400000z"></path></svg></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.18278000000000005em;"><span></span></span></span></span></span></span></span><span style="top:-3.23em;"><span class="pstrut" style="height:3em;"></span><span class="frac-line" style="border-bottom-width:0.04em"></span></span><span style="top:-3.677em;"><span class="pstrut" style="height:3em;"></span><span class="mord" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqbi" style="">Q</span><span class="mord coloredeq eqbi" style=""><span class="mord mathnormal" style="margin-right:0.07153em">K</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.849108em;"><span style="top:-3.063em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style="">⊤</span></span></span></span></span></span></span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.93em;"><span></span></span></span></span></span><span class="mclose nulldelimiter"></span></span></span><span class="mord" style=""><span class="delimsizing size4" style=""><span style="">)</span></span></span></span><span class="mord mathnormal" style="margin-right:0.22222em;">V</span></span></span></span></span></span> </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">533</span>        <span class="n">x</span> <span class="o">=</span> <span class="n">jnp</span><span class="o">.</span><span class="n">einsum</span><span class="p">(</span><span class="s2">&quot;ijh,jhd-&gt;ihd&quot;</span><span class="p">,</span> <span class="n">attn</span><span class="p">,</span> <span class="n">value</span><span class="p">)</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-110'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-110'>#</a>
            </div>
            <p>Concatenate multiple heads </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">536</span>        <span class="n">x</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">seq_len</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-111'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-111'>#</a>
            </div>
            <p>Output layer </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">539</span>        <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">output</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-112'>
        <div class='docs doc-strings'>
            <div class='section-link'>
                <a href='#section-112'>#</a>
            </div>
            <p> <a id="FFN"></a></p>
<h2>Position-wise Feed-Forward layer</h2>
<p>This is based on <a href="https://nn.labml.ai/transformers/feed_forward.html">our PyTorch implementation</a>.</p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">542</span><span class="k">class</span> <span class="nc">FeedForward</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-113'>
        <div class='docs doc-strings'>
            <div class='section-link'>
                <a href='#section-113'>#</a>
            </div>
            <ul><li><code  class="highlight"><span></span><span class="n">rnd_key</span></code>
 is the PRNG state </li>
<li><code  class="highlight"><span></span><span class="n">d_model</span></code>
 is the number of features in a token embedding </li>
<li><code  class="highlight"><span></span><span class="n">d_ff</span></code>
 is the number of features in the hidden layer of the FFN </li>
<li><code  class="highlight"><span></span><span class="n">activation</span></code>
 is the activation function <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.8888799999999999em;vertical-align:-0.19444em;"></span><span class="mord coloredeq eqcd" style=""><span class="mord mathnormal" style="margin-right:0.10764em">f</span></span></span></span></span></span></li></ul>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">552</span>    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">rnd_key</span><span class="p">:</span> <span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">PRNGKey</span><span class="p">,</span> <span class="n">d_model</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">d_ff</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="lineno">553</span>                 <span class="n">activation</span><span class="o">=</span><span class="n">jax</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">relu</span><span class="p">):</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-114'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-114'>#</a>
            </div>
            
        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">560</span>        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-115'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-115'>#</a>
            </div>
            <p>Split the PRNG state into sub-keys; the first two sub-keys initialize the two linear layers </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">562</span>        <span class="n">_</span><span class="p">,</span> <span class="o">*</span><span class="n">rnd_keys</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="n">rnd_key</span><span class="p">,</span> <span class="mi">5</span><span class="p">)</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-116'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-116'>#</a>
            </div>
            <p>Layer one parameterized by weight <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.83333em;vertical-align:-0.15em;"></span><span class="mord coloredeq eqbt" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.13889em">W</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.30110799999999993em;"><span style="top:-2.5500000000000003em;margin-left:-0.13889em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style="">1</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span></span> and bias <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.84444em;vertical-align:-0.15em;"></span><span class="mord coloredeq eqbu" style=""><span class="mord" style=""><span class="mord mathnormal" style="">b</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.30110799999999993em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style="">1</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span></span> </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">565</span>        <span class="bp">self</span><span class="o">.</span><span class="n">layer1</span> <span class="o">=</span> <span class="n">Linear</span><span class="p">(</span><span class="n">rnd_keys</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">d_model</span><span class="p">,</span> <span class="n">d_ff</span><span class="p">)</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-117'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-117'>#</a>
            </div>
            <p>Layer two parameterized by weight <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.83333em;vertical-align:-0.15em;"></span><span class="mord" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.13889em">W</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.30110799999999993em;"><span style="top:-2.5500000000000003em;margin-left:-0.13889em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style="">2</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span></span> and bias <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.84444em;vertical-align:-0.15em;"></span><span class="mord" style=""><span class="mord" style=""><span class="mord mathnormal" style="">b</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.30110799999999993em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style="">2</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span></span> </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">567</span>        <span class="bp">self</span><span class="o">.</span><span class="n">layer2</span> <span class="o">=</span> <span class="n">Linear</span><span class="p">(</span><span class="n">rnd_keys</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="n">d_ff</span><span class="p">,</span> <span class="n">d_model</span><span class="p">)</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-118'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-118'>#</a>
            </div>
            <p>Activation function <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.8888799999999999em;vertical-align:-0.19444em;"></span><span class="mord coloredeq eqcd" style=""><span class="mord mathnormal" style="margin-right:0.10764em">f</span></span></span></span></span></span> </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">569</span>        <span class="bp">self</span><span class="o">.</span><span class="n">activation</span> <span class="o">=</span> <span class="n">activation</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-119'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-119'>#</a>
            </div>
            
        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">571</span>    <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">jnp</span><span class="o">.</span><span class="n">ndarray</span><span class="p">):</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-120'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-120'>#</a>
            </div>
            <p><span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord coloredeq eqbc" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqcd" style="margin-right:0.10764em">f</span></span><span class="mopen" style="">(</span><span class="mord mathnormal" style="">x</span><span class="mord" style=""><span class="mord coloredeq eqbt" style=""><span class="mord mathnormal" style="margin-right:0.13889em">W</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.30110799999999993em;"><span style="top:-2.5500000000000003em;margin-left:-0.13889em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style="">1</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin" style="">+</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mord" style=""><span class="mord coloredeq eqbu" style=""><span class="mord mathnormal" style="">b</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.30110799999999993em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style="">1</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span><span class="mclose" 
style="">)</span></span></span></span></span></span> </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">573</span>        <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">activation</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">layer1</span><span class="p">(</span><span class="n">x</span><span class="p">))</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-121'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-121'>#</a>
            </div>
            <p><span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord coloredeq eqbc" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqcd" style="margin-right:0.10764em">f</span></span><span class="mopen" style="">(</span><span class="mord mathnormal" style="">x</span><span class="mord" style=""><span class="mord coloredeq eqbt" style=""><span class="mord mathnormal" style="margin-right:0.13889em">W</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.30110799999999993em;"><span style="top:-2.5500000000000003em;margin-left:-0.13889em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style="">1</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin" style="">+</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mord" style=""><span class="mord coloredeq eqbu" style=""><span class="mord mathnormal" style="">b</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.30110799999999993em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style="">1</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span><span class="mclose" style="">)</span></span><span 
class="mord"><span class="mord mathnormal" style="margin-right:0.13889em;">W</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.30110799999999993em;"><span style="top:-2.5500000000000003em;margin-left:-0.13889em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">2</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin">+</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span></span><span class="base"><span class="strut" style="height:0.84444em;vertical-align:-0.15em;"></span><span class="mord"><span class="mord mathnormal">b</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.30110799999999993em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">2</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span> </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">575</span>        <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">layer2</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-122'>
        <div class='docs doc-strings'>
            <div class='section-link'>
                <a href='#section-122'>#</a>
            </div>
            <p> <a id="TransformerLayer"></a></p>
<h2>Transformer Layer</h2>
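<p>This is a transformer layer with multi-head self-attention and a position-wise feed-forward layer. We use pre-layer normalization: the input to each sub-layer is normalized first, and the sub-layer's output is added back to the un-normalized input through a residual connection.</p>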

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">578</span><span class="k">class</span> <span class="nc">TransformerLayer</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-123'>
        <div class='docs doc-strings'>
            <div class='section-link'>
                <a href='#section-123'>#</a>
            </div>
            <ul><li><code  class="highlight"><span></span><span class="n">d_model</span></code>
 is the token embedding size </li>
<li><code  class="highlight"><span></span><span class="n">self_attn</span></code>
 is the self attention module </li>
<li><code  class="highlight"><span></span><span class="n">feed_forward</span></code>
 is the feed forward module</li></ul>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">588</span>    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span>
<span class="lineno">589</span>                 <span class="n">d_model</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="lineno">590</span>                 <span class="n">self_attn</span><span class="p">:</span> <span class="n">MultiHeadAttention</span><span class="p">,</span>
<span class="lineno">591</span>                 <span class="n">feed_forward</span><span class="p">:</span> <span class="n">FeedForward</span><span class="p">):</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-124'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-124'>#</a>
            </div>
            
        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">597</span>        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
<span class="lineno">598</span>        <span class="bp">self</span><span class="o">.</span><span class="n">size</span> <span class="o">=</span> <span class="n">d_model</span>
<span class="lineno">599</span>        <span class="bp">self</span><span class="o">.</span><span class="n">self_attn</span> <span class="o">=</span> <span class="n">self_attn</span>
<span class="lineno">600</span>        <span class="bp">self</span><span class="o">.</span><span class="n">feed_forward</span> <span class="o">=</span> <span class="n">feed_forward</span>
<span class="lineno">601</span>        <span class="bp">self</span><span class="o">.</span><span class="n">norm_self_attn</span> <span class="o">=</span> <span class="n">LayerNorm</span><span class="p">([</span><span class="n">d_model</span><span class="p">])</span>
<span class="lineno">602</span>        <span class="bp">self</span><span class="o">.</span><span class="n">norm_ff</span> <span class="o">=</span> <span class="n">LayerNorm</span><span class="p">([</span><span class="n">d_model</span><span class="p">])</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-125'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-125'>#</a>
            </div>
            
        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">604</span>    <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">jnp</span><span class="o">.</span><span class="n">ndarray</span><span class="p">,</span> <span class="n">mask</span><span class="p">:</span> <span class="n">jnp</span><span class="o">.</span><span class="n">ndarray</span><span class="p">):</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-126'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-126'>#</a>
            </div>
            <p>Normalize the vectors before doing self attention </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">606</span>        <span class="n">z</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm_self_attn</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-127'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-127'>#</a>
            </div>
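            <p>Run through self-attention, i.e. the queries, keys and values all come from the same normalized input <code  class="highlight"><span></span><span class="n">z</span></code>, and add the result to the residual <code  class="highlight"><span></span><span class="n">x</span></code> </p>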

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">608</span>        <span class="n">self_attn</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">self_attn</span><span class="p">(</span><span class="n">query</span><span class="o">=</span><span class="n">z</span><span class="p">,</span> <span class="n">key</span><span class="o">=</span><span class="n">z</span><span class="p">,</span> <span class="n">value</span><span class="o">=</span><span class="n">z</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">mask</span><span class="p">)</span>
<span class="lineno">609</span>        <span class="n">x</span> <span class="o">=</span> <span class="n">x</span> <span class="o">+</span> <span class="n">self_attn</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-128'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-128'>#</a>
            </div>
            <p>Normalize for feed-forward </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">612</span>        <span class="n">z</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm_ff</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-129'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-129'>#</a>
            </div>
            <p>Pass through the feed-forward network </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">614</span>        <span class="n">ff</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">feed_forward</span><span class="p">(</span><span class="n">z</span><span class="p">)</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-130'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-130'>#</a>
            </div>
            <p>Add the feed-forward results to the residual <code  class="highlight"><span></span><span class="n">x</span></code> </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">616</span>        <span class="n">x</span> <span class="o">=</span> <span class="n">x</span> <span class="o">+</span> <span class="n">ff</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-131'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-131'>#</a>
            </div>
            <p> </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">618</span>        <span class="k">return</span> <span class="n">x</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-132'>
        <div class='docs doc-strings'>
            <div class='section-link'>
                <a href='#section-132'>#</a>
            </div>
            <p> <a id="CrossEntropyLoss"></a></p>
<h2>Cross Entropy Loss</h2>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">621</span><span class="k">class</span> <span class="nc">CrossEntropyLoss</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-133'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-133'>#</a>
            </div>
            
        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">628</span>    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="lineno">629</span>        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-134'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-134'>#</a>
            </div>
            <p>Use <code  class="highlight"><span></span><span class="n">jax</span><span class="o">.</span><span class="n">vmap</span></code>
 to vectorize the loss function </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">632</span>        <span class="bp">self</span><span class="o">.</span><span class="n">_loss_vmap</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">vmap</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_loss</span><span class="p">,</span> <span class="n">in_axes</span><span class="o">=</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">,))</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-135'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-135'>#</a>
            </div>
            
        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">634</span>    <span class="k">def</span> <span class="nf">_loss</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">output</span><span class="p">:</span> <span class="n">jnp</span><span class="o">.</span><span class="n">ndarray</span><span class="p">,</span> <span class="n">target</span><span class="p">:</span> <span class="n">jnp</span><span class="o">.</span><span class="n">ndarray</span><span class="p">):</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-136'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-136'>#</a>
            </div>
            <p><span ><span class="katex-display"><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:2.3521180000000004em;vertical-align:-1.3021129999999999em;"></span><span class="mord">−</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mop op-limits"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:1.0500050000000005em;"><span style="top:-1.847887em;margin-left:0em;"><span class="pstrut" style="height:3.05em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight" style="margin-right:0.03148em;">k</span></span></span><span style="top:-3.050005em;"><span class="pstrut" style="height:3.05em;"></span><span><span class="mop op-symbol large-op">∑</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:1.3021129999999999em;"><span></span></span></span></span></span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.03588em;">y</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.33610799999999996em;"><span style="top:-2.5500000000000003em;margin-left:-0.03588em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight" style="margin-right:0.03148em;">k</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mop">lo<span style="margin-right:0.01389em;">g</span></span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord"><span class="mord 
accent"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.69444em;"><span style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="mord mathnormal" style="margin-right:0.03588em;">y</span></span><span style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="accent-body" style="left:-0.19444em;"><span class="mord">^</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.19444em;"><span></span></span></span></span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.33610799999999996em;"><span style="top:-2.5500000000000003em;margin-left:-0.03588em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight" style="margin-right:0.03148em;">k</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span></span> </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">636</span>        <span class="k">return</span> <span class="o">-</span><span class="n">jax</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">log_softmax</span><span class="p">(</span><span class="n">output</span><span class="p">)[</span><span class="n">target</span><span class="p">]</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-137'>
        <div class='docs doc-strings'>
            <div class='section-link'>
                <a href='#section-137'>#</a>
            </div>
            <ul><li><code  class="highlight"><span></span><span class="n">output</span></code>
 is the model output of shape <code  class="highlight"><span></span><span class="p">[</span><span class="n">seq_len</span><span class="p">,</span> <span class="n">n_vocab</span><span class="p">]</span></code>
 </li>
<li><code  class="highlight"><span></span><span class="n">target</span></code>
 is the target of shape <code  class="highlight"><span></span><span class="p">[</span><span class="n">seq_len</span><span class="p">]</span></code>
</li></ul>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">638</span>    <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">output</span><span class="p">:</span> <span class="n">jnp</span><span class="o">.</span><span class="n">ndarray</span><span class="p">,</span> <span class="n">target</span><span class="p">:</span> <span class="n">jnp</span><span class="o">.</span><span class="n">ndarray</span><span class="p">):</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-138'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-138'>#</a>
            </div>
            <p>Use the vectorized loss function and calculate the mean.</p>
<p>We could have used a for loop to calculate the losses, but using vmap is about 10X faster.</p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">647</span>        <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_loss_vmap</span><span class="p">(</span><span class="n">output</span><span class="p">,</span> <span class="n">target</span><span class="p">)</span><span class="o">.</span><span class="n">mean</span><span class="p">()</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-139'>
        <div class='docs doc-strings'>
            <div class='section-link'>
                <a href='#section-139'>#</a>
            </div>
            <p> <a id="AutoregressiveTransformer"></a></p>
<h2>Autoregressive Transformer</h2>
<p>This is the transformer decoder with embedding and output layers.</p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">650</span><span class="k">class</span> <span class="nc">AutoregressiveTransformer</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-140'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-140'>#</a>
            </div>
            
        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">658</span>    <span class="n">layers</span><span class="p">:</span> <span class="n">ModuleList</span><span class="p">[</span><span class="n">TransformerLayer</span><span class="p">]</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-141'>
        <div class='docs doc-strings'>
            <div class='section-link'>
                <a href='#section-141'>#</a>
            </div>
            <ul><li><code  class="highlight"><span></span><span class="n">rnd_key</span></code>
 is the PRNG state </li>
<li><code  class="highlight"><span></span><span class="n">n_vocab</span></code>
 is the vocabulary size </li>
<li><code  class="highlight"><span></span><span class="n">d_model</span></code>
 is the number of features in a token embedding </li>
<li><code  class="highlight"><span></span><span class="n">n_layers</span></code>
 is the number of transformer layers </li>
<li><code  class="highlight"><span></span><span class="n">heads</span></code>
 is the number of attention heads </li>
<li><code  class="highlight"><span></span><span class="n">d_ff</span></code>
 is the number of features in the hidden layer of the FFN</li></ul>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">660</span>    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">rnd_key</span><span class="p">:</span> <span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">PRNGKey</span><span class="p">,</span> <span class="n">n_vocab</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">d_model</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">n_layers</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">heads</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">d_ff</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-142'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-142'>#</a>
            </div>
            
        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">669</span>        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
<span class="lineno">670</span>        <span class="bp">self</span><span class="o">.</span><span class="n">n_vocab</span> <span class="o">=</span> <span class="n">n_vocab</span>
<span class="lineno">671</span>        <span class="bp">self</span><span class="o">.</span><span class="n">d_model</span> <span class="o">=</span> <span class="n">d_model</span>
<span class="lineno">672</span>        <span class="bp">self</span><span class="o">.</span><span class="n">loss_func</span> <span class="o">=</span> <span class="n">CrossEntropyLoss</span><span class="p">()</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-143'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-143'>#</a>
            </div>
            <p>For transformer layers </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">675</span>        <span class="n">layers</span> <span class="o">=</span> <span class="p">[]</span>
<span class="lineno">676</span>        <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">n_layers</span><span class="p">):</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-144'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-144'>#</a>
            </div>
            <p>Split PRNG state </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">678</span>            <span class="n">rnd_key</span><span class="p">,</span> <span class="n">mha_key</span><span class="p">,</span> <span class="n">ffn_key</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="n">rnd_key</span><span class="p">,</span> <span class="mi">3</span><span class="p">)</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-145'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-145'>#</a>
            </div>
            <p>Create a transformer layer </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">680</span>            <span class="n">attn</span> <span class="o">=</span> <span class="n">MultiHeadAttention</span><span class="p">(</span><span class="n">mha_key</span><span class="p">,</span> <span class="n">heads</span><span class="p">,</span> <span class="n">d_model</span><span class="p">)</span>
<span class="lineno">681</span>            <span class="n">ffn</span> <span class="o">=</span> <span class="n">FeedForward</span><span class="p">(</span><span class="n">ffn_key</span><span class="p">,</span> <span class="n">d_model</span><span class="p">,</span> <span class="n">d_ff</span><span class="p">)</span>
<span class="lineno">682</span>            <span class="n">layers</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">TransformerLayer</span><span class="p">(</span><span class="n">d_model</span><span class="p">,</span> <span class="n">attn</span><span class="p">,</span> <span class="n">ffn</span><span class="p">))</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-146'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-146'>#</a>
            </div>
            <p>Make a module list </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">684</span>        <span class="bp">self</span><span class="o">.</span><span class="n">layers</span> <span class="o">=</span> <span class="n">ModuleList</span><span class="p">(</span><span class="n">layers</span><span class="p">)</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-147'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-147'>#</a>
            </div>
            <p>Split PRNG state </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">687</span>        <span class="n">rnd_key</span><span class="p">,</span> <span class="n">emb_key</span><span class="p">,</span> <span class="n">out_key</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="n">rnd_key</span><span class="p">,</span> <span class="mi">3</span><span class="p">)</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-148'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-148'>#</a>
            </div>
            <p>Create embedding layer </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">689</span>        <span class="bp">self</span><span class="o">.</span><span class="n">embeddings</span> <span class="o">=</span> <span class="n">EmbeddingsWithLearnedPositionalEncoding</span><span class="p">(</span><span class="n">emb_key</span><span class="p">,</span> <span class="n">n_vocab</span><span class="p">,</span> <span class="n">d_model</span><span class="p">)</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-149'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-149'>#</a>
            </div>
            <p>Final normalization and output layer </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">691</span>        <span class="bp">self</span><span class="o">.</span><span class="n">norm</span> <span class="o">=</span> <span class="n">LayerNorm</span><span class="p">([</span><span class="n">d_model</span><span class="p">])</span>
<span class="lineno">692</span>        <span class="bp">self</span><span class="o">.</span><span class="n">output</span> <span class="o">=</span> <span class="n">Linear</span><span class="p">(</span><span class="n">out_key</span><span class="p">,</span> <span class="n">d_model</span><span class="p">,</span> <span class="n">n_vocab</span><span class="p">)</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-150'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-150'>#</a>
            </div>
            
        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">694</span>    <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">jnp</span><span class="o">.</span><span class="n">ndarray</span><span class="p">):</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-151'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-151'>#</a>
            </div>
            <p>Get sequence length </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">696</span>        <span class="n">seq_len</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-152'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-152'>#</a>
            </div>
            <p>A causal mask for attention so that a token can only see itself and the tokens before it </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">698</span>        <span class="n">mask</span> <span class="o">=</span> <span class="n">jnp</span><span class="o">.</span><span class="n">tril</span><span class="p">(</span><span class="n">jnp</span><span class="o">.</span><span class="n">ones</span><span class="p">((</span><span class="n">seq_len</span><span class="p">,</span> <span class="n">seq_len</span><span class="p">),</span> <span class="nb">bool</span><span class="p">))</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-153'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-153'>#</a>
            </div>
            <p>Get embeddings with positional encodings </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">700</span>        <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">embeddings</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-154'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-154'>#</a>
            </div>
            <p>Apply the transformer layers </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">702</span>        <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">layers</span><span class="p">)):</span>
<span class="lineno">703</span>            <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">layers</span><span class="p">[</span><span class="n">i</span><span class="p">](</span><span class="n">x</span><span class="p">,</span> <span class="n">mask</span><span class="p">)</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-155'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-155'>#</a>
            </div>
            <p>Final normalization and linear transformation to get the logits </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">706</span>        <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">output</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">norm</span><span class="p">(</span><span class="n">x</span><span class="p">))</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-156'>
        <div class='docs doc-strings'>
            <div class='section-link'>
                <a href='#section-156'>#</a>
            </div>
            <h3>Calculate the loss</h3>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">708</span>    <span class="k">def</span> <span class="nf">get_loss</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">jnp</span><span class="o">.</span><span class="n">ndarray</span><span class="p">):</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-157'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-157'>#</a>
            </div>
            <p>Get model outputs </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">713</span>        <span class="n">output</span> <span class="o">=</span> <span class="bp">self</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-158'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-158'>#</a>
            </div>
            <p>Cross entropy loss </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">715</span>        <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">loss_func</span><span class="p">(</span><span class="n">output</span><span class="p">[:</span><span class="o">-</span><span class="mi">1</span><span class="p">],</span> <span class="n">x</span><span class="p">[</span><span class="mi">1</span><span class="p">:])</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-159'>
        <div class='docs doc-strings'>
            <div class='section-link'>
                <a href='#section-159'>#</a>
            </div>
            <h3>Sample</h3>
<p>The starting sequence is given by <code  class="highlight"><span></span><span class="n">seq</span></code>
 and we greedily sample <code  class="highlight"><span></span><span class="n">length</span></code>
 tokens</p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">717</span>    <span class="k">def</span> <span class="nf">sample</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">seq</span><span class="p">:</span> <span class="n">jnp</span><span class="o">.</span><span class="n">ndarray</span><span class="p">,</span> <span class="n">length</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">20</span><span class="p">):</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-160'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-160'>#</a>
            </div>
            
        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">723</span>        <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">length</span><span class="p">):</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-161'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-161'>#</a>
            </div>
            <p>Sample the highest probability token </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">725</span>            <span class="n">idx</span> <span class="o">=</span> <span class="n">jnp</span><span class="o">.</span><span class="n">argmax</span><span class="p">(</span><span class="bp">self</span><span class="p">(</span><span class="n">seq</span><span class="p">)[</span><span class="o">-</span><span class="mi">1</span><span class="p">])</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-162'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-162'>#</a>
            </div>
            <p>Add it to the sequence </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">727</span>            <span class="n">seq</span> <span class="o">=</span> <span class="n">jnp</span><span class="o">.</span><span class="n">concatenate</span><span class="p">((</span><span class="n">seq</span><span class="p">,</span> <span class="n">idx</span><span class="p">[</span><span class="kc">None</span><span class="p">]))</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-163'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-163'>#</a>
            </div>
            <p>Return the sampled sequence </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">730</span>        <span class="k">return</span> <span class="n">seq</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-164'>
        <div class='docs doc-strings'>
            <div class='section-link'>
                <a href='#section-164'>#</a>
            </div>
            <p> This is a named tuple for storing the Adam optimizer state for a parameter.</p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">733</span><span class="k">class</span> <span class="nc">AdamState</span><span class="p">(</span><span class="n">NamedTuple</span><span class="p">):</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-165'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-165'>#</a>
            </div>
            
        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">737</span>    <span class="n">m</span><span class="p">:</span> <span class="n">jnp</span><span class="o">.</span><span class="n">ndarray</span>
<span class="lineno">738</span>    <span class="n">v</span><span class="p">:</span> <span class="n">jnp</span><span class="o">.</span><span class="n">ndarray</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-166'>
        <div class='docs doc-strings'>
            <div class='section-link'>
                <a href='#section-166'>#</a>
            </div>
            <p> <a id="Adam"></a></p>
<h2>Adam Optimizer</h2>
<p>This is from the paper <a href="https://papers.labml.ai/paper/1412.6980">Adam: A Method for Stochastic Optimization</a>.</p>
<p>For parameter <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.84444em;vertical-align:-0.15em;"></span><span class="mord coloredeq eqbk" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.02778em">θ</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.2805559999999999em;"><span style="top:-2.5500000000000003em;margin-left:-0.02778em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqce" style="">t</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span></span> and gradient <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.625em;vertical-align:-0.19444em;"></span><span class="mord coloredeq eqbw" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.03588em">g</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.2805559999999999em;"><span style="top:-2.5500000000000003em;margin-left:-0.03588em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqce" style="">t</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span></span> at step <span ><span class="katex"><span aria-hidden="true" 
class="katex-html"><span class="base"><span class="strut" style="height:0.61508em;vertical-align:0em;"></span><span class="mord coloredeq eqce" style=""><span class="mord mathnormal" style="">t</span></span></span></span></span></span>, the Adam update is,</p>
<span ><span class="katex-display"><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:10.365339999999996em;vertical-align:-4.932669999999998em;"></span><span class="mord"><span class="mtable"><span class="col-align-r"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:5.432669999999999em;"><span style="top:-7.964109999999999em;"><span class="pstrut" style="height:3.3714399999999998em;"></span><span class="mord"><span class="mord coloredeq eqbx" style=""><span class="mord" style=""><span class="mord mathnormal" style="">m</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.2805559999999999em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqce" style="">t</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span><span style="top:-6.440001999999999em;"><span class="pstrut" style="height:3.3714399999999998em;"></span><span class="mord"><span class="mord coloredeq eqby" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.03588em">v</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.2805559999999999em;"><span style="top:-2.5500000000000003em;margin-left:-0.03588em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqce" style="">t</span></span></span></span></span><span 
class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span><span style="top:-4.672441999999999em;"><span class="pstrut" style="height:3.3714399999999998em;"></span><span class="mord"><span class="mord coloredeq eqbg" style=""><span class="mord" style=""><span class="mord accent" style=""><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.69444em;"><span style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="mord mathnormal" style="">m</span></span><span style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="accent-body" style="left:-0.25em;"><span class="mord" style="">^</span></span></span></span></span></span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.2805559999999999em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqce" style="">t</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span><span style="top:-2.3025460000000004em;"><span class="pstrut" style="height:3.3714399999999998em;"></span><span class="mord"><span class="mord coloredeq eqbh" style=""><span class="mord" style=""><span class="mord accent" style=""><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.69444em;"><span style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="mord mathnormal" style="margin-right:0.03588em">v</span></span><span style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="accent-body" 
style="left:-0.22222em;"><span class="mord" style="">^</span></span></span></span></span></span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.2805559999999999em;"><span style="top:-2.5500000000000003em;margin-left:-0.03588em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqce" style="">t</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span><span style="top:0.33122999999999836em;"><span class="pstrut" style="height:3.3714399999999998em;"></span><span class="mord"><span class="mord coloredeq eqbk" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.02778em">θ</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.2805559999999999em;"><span style="top:-2.5500000000000003em;margin-left:-0.02778em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqce" style="">t</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:4.932669999999998em;"><span></span></span></span></span></span><span class="col-align-l"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:5.432669999999999em;"><span style="top:-7.964109999999999em;"><span class="pstrut" 
style="height:3.3714399999999998em;"></span><span class="mord"><span class="mord"></span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">←</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mord coloredeq eqbl" style=""><span class="mord" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqbs" style="margin-right:0.05278em">β</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.30110799999999993em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style="">1</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span><span class="mord coloredeq eqbn" style=""><span class="mord" style=""><span class="mord mathnormal" style="">m</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.301108em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqce" style="">t</span></span><span class="mbin mtight" style="">−</span><span class="mord mtight" style="">1</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.208331em;"><span></span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin">+</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span 
class="mopen">(</span><span class="mord">1</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin">−</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mord coloredeq eqbl" style=""><span class="mord" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqbs" style="margin-right:0.05278em">β</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.30110799999999993em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style="">1</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span><span class="mclose">)</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin">⋅</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mord coloredeq eqbw" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.03588em">g</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.2805559999999999em;"><span style="top:-2.5500000000000003em;margin-left:-0.03588em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqce" style="">t</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span><span style="top:-6.440001999999999em;"><span class="pstrut" 
style="height:3.3714399999999998em;"></span><span class="mord"><span class="mord"></span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">←</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mord coloredeq eqbm" style=""><span class="mord" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqbs" style="margin-right:0.05278em">β</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.30110799999999993em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style="">2</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span><span class="mord coloredeq eqbp" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.03588em">v</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.301108em;"><span style="top:-2.5500000000000003em;margin-left:-0.03588em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqce" style="">t</span></span><span class="mbin mtight" style="">−</span><span class="mord mtight" style="">1</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.208331em;"><span></span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin">+</span><span class="mspace" 
style="margin-right:0.2222222222222222em;"></span><span class="mopen">(</span><span class="mord">1</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin">−</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mord coloredeq eqbm" style=""><span class="mord" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqbs" style="margin-right:0.05278em">β</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.30110799999999993em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style="">2</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span><span class="mclose">)</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin">⋅</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mord"><span class="mord coloredeq eqbw" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.03588em">g</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.2805559999999999em;"><span style="top:-2.5500000000000003em;margin-left:-0.03588em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqce" style="">t</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span><span class="msupsub"><span 
class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.8641079999999999em;"><span style="top:-3.113em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">2</span></span></span></span></span></span></span></span></span></span><span style="top:-4.672441999999999em;"><span class="pstrut" style="height:3.3714399999999998em;"></span><span class="mord"><span class="mord"></span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">←</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mord"><span class="mopen nulldelimiter"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:1.1075599999999999em;"><span style="top:-2.2321040000000005em;"><span class="pstrut" style="height:3em;"></span><span class="mord"><span class="mord">1</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin">−</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mord"><span class="mord coloredeq eqbl" style=""><span class="mord" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqbs" style="margin-right:0.05278em">β</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.30110799999999993em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style="">1</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" 
style="height:0.8778959999999999em;"><span style="top:-3.1473400000000002em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight coloredeq eqce" style=""><span class="mord mathnormal mtight" style="">t</span></span></span></span></span></span></span></span></span></span></span><span style="top:-3.23em;"><span class="pstrut" style="height:3em;"></span><span class="frac-line" style="border-bottom-width:0.04em;"></span></span><span style="top:-3.677em;"><span class="pstrut" style="height:3em;"></span><span class="mord"><span class="mord coloredeq eqbx" style=""><span class="mord" style=""><span class="mord mathnormal" style="">m</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.2805559999999999em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqce" style="">t</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.9623359999999999em;"><span></span></span></span></span></span><span class="mclose nulldelimiter"></span></span></span></span><span style="top:-2.3025460000000004em;"><span class="pstrut" style="height:3.3714399999999998em;"></span><span class="mord"><span class="mord"></span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">←</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mord"><span class="mopen nulldelimiter"></span><span class="mfrac"><span class="vlist-t 
vlist-t2"><span class="vlist-r"><span class="vlist" style="height:1.1075599999999999em;"><span style="top:-2.2321040000000005em;"><span class="pstrut" style="height:3em;"></span><span class="mord"><span class="mord">1</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin">−</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mord"><span class="mord coloredeq eqbm" style=""><span class="mord" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqbs" style="margin-right:0.05278em">β</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.30110799999999993em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style="">2</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.8778959999999999em;"><span style="top:-3.1473400000000002em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight coloredeq eqce" style=""><span class="mord mathnormal mtight" style="">t</span></span></span></span></span></span></span></span></span></span></span><span style="top:-3.23em;"><span class="pstrut" style="height:3em;"></span><span class="frac-line" style="border-bottom-width:0.04em;"></span></span><span style="top:-3.677em;"><span class="pstrut" style="height:3em;"></span><span class="mord"><span class="mord coloredeq eqby" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.03588em">v</span><span class="msupsub"><span 
class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.2805559999999999em;"><span style="top:-2.5500000000000003em;margin-left:-0.03588em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqce" style="">t</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.9623359999999999em;"><span></span></span></span></span></span><span class="mclose nulldelimiter"></span></span></span></span><span style="top:0.33122999999999836em;"><span class="pstrut" style="height:3.3714399999999998em;"></span><span class="mord"><span class="mord"></span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">←</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.02778em;">θ</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.301108em;"><span style="top:-2.5500000000000003em;margin-left:-0.02778em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight coloredeq eqce" style=""><span class="mord mathnormal mtight" style="">t</span></span><span class="mbin mtight">−</span><span class="mord mtight">1</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.208331em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span 
class="mbin">−</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mord coloredeq eqbq" style=""><span class="mord mathnormal" style="margin-right:0.0037em">α</span></span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin">⋅</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mord"><span class="mopen nulldelimiter"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:1.37144em;"><span style="top:-2.25278em;"><span class="pstrut" style="height:3em;"></span><span class="mord"><span class="mord sqrt"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.85722em;"><span class="svg-align" style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="mord" style="padding-left:0.833em;"><span class="mord coloredeq eqbh" style=""><span class="mord" style=""><span class="mord accent" style=""><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.69444em;"><span style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="mord mathnormal" style="margin-right:0.03588em">v</span></span><span style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="accent-body" style="left:-0.22222em;"><span class="mord" style="">^</span></span></span></span></span></span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.2805559999999999em;"><span style="top:-2.5500000000000003em;margin-left:-0.03588em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqce" style="">t</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" 
style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span><span style="top:-2.81722em;"><span class="pstrut" style="height:3em;"></span><span class="hide-tail" style="min-width:0.853em;height:1.08em;"><svg height="1.08em" preserveaspectratio="xMinYMin slice" viewbox="0 0 400000 1080" width="400em" xmlns="http://www.w3.org/2000/svg"><path d="M95,702
c-2.7,0,-7.17,-2.7,-13.5,-8c-5.8,-5.3,-9.5,-10,-9.5,-14
c0,-2,0.3,-3.3,1,-4c1.3,-2.7,23.83,-20.7,67.5,-54
c44.2,-33.3,65.8,-50.3,66.5,-51c1.3,-1.3,3,-2,5,-2c4.7,0,8.7,3.3,12,10
s173,378,173,378c0.7,0,35.3,-71,104,-213c68.7,-142,137.5,-285,206.5,-429
c69,-144,104.5,-217.7,106.5,-221
l0 -0
c5.3,-9.3,12,-14,20,-14
H400000v40H845.2724
s-225.272,467,-225.272,467s-235,486,-235,486c-2.7,4.7,-9,7,-19,7
c-6,0,-10,-1,-12,-3s-194,-422,-194,-422s-65,47,-65,47z
M834 80h400000v40h-400000z"></path></svg></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.18278000000000005em;"><span></span></span></span></span></span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin">+</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mord coloredeq eqbj" style=""><span class="mord mathnormal" style="">ϵ</span></span></span></span><span style="top:-3.23em;"><span class="pstrut" style="height:3em;"></span><span class="frac-line" style="border-bottom-width:0.04em;"></span></span><span style="top:-3.677em;"><span class="pstrut" style="height:3em;"></span><span class="mord"><span class="mord coloredeq eqbg" style=""><span class="mord" style=""><span class="mord accent" style=""><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.69444em;"><span style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="mord mathnormal" style="">m</span></span><span style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="accent-body" style="left:-0.25em;"><span class="mord" style="">^</span></span></span></span></span></span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.2805559999999999em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqce" style="">t</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" 
style="height:0.93em;"><span></span></span></span></span></span><span class="mclose nulldelimiter"></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:4.932669999999998em;"><span></span></span></span></span></span></span></span></span></span></span></span></span><p>where <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.43056em;vertical-align:0em;"></span><span class="mord coloredeq eqbq" style=""><span class="mord mathnormal" style="margin-right:0.0037em">α</span></span></span></span></span></span>, <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.8888799999999999em;vertical-align:-0.19444em;"></span><span class="mord coloredeq eqbl" style=""><span class="mord" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqbs" style="margin-right:0.05278em">β</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.30110799999999993em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style="">1</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span></span>, <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.8888799999999999em;vertical-align:-0.19444em;"></span><span class="mord coloredeq eqbm" style=""><span class="mord" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqbs" style="margin-right:0.05278em">β</span></span><span class="msupsub"><span class="vlist-t 
vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.30110799999999993em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style="">2</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span></span> and <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.43056em;vertical-align:0em;"></span><span class="mord coloredeq eqbj" style=""><span class="mord mathnormal" style="">ϵ</span></span></span></span></span></span> are scalar hyper-parameters. <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.58056em;vertical-align:-0.15em;"></span><span class="mord coloredeq eqbx" style=""><span class="mord" style=""><span class="mord mathnormal" style="">m</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.2805559999999999em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqce" style="">t</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span></span> and <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.58056em;vertical-align:-0.15em;"></span><span class="mord coloredeq eqby" style=""><span class="mord" style=""><span 
class="mord mathnormal" style="margin-right:0.03588em">v</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.2805559999999999em;"><span style="top:-2.5500000000000003em;margin-left:-0.03588em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqce" style="">t</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span></span> are first and second order moments. <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.84444em;vertical-align:-0.15em;"></span><span class="mord coloredeq eqbg" style=""><span class="mord" style=""><span class="mord accent" style=""><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.69444em;"><span style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="mord mathnormal" style="">m</span></span><span style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="accent-body" style="left:-0.25em;"><span class="mord" style="">^</span></span></span></span></span></span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.2805559999999999em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqce" style="">t</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" 
style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span></span> and <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.84444em;vertical-align:-0.15em;"></span><span class="mord coloredeq eqbh" style=""><span class="mord" style=""><span class="mord accent" style=""><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.69444em;"><span style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="mord mathnormal" style="margin-right:0.03588em">v</span></span><span style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="accent-body" style="left:-0.22222em;"><span class="mord" style="">^</span></span></span></span></span></span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.2805559999999999em;"><span style="top:-2.5500000000000003em;margin-left:-0.03588em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqce" style="">t</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span></span> are the bias-corrected moments. <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.43056em;vertical-align:0em;"></span><span class="mord coloredeq eqbj" style=""><span class="mord mathnormal" style="">ϵ</span></span></span></span></span></span> is added to avoid division by zero, but it also acts as a hyper-parameter that counteracts variance in the gradients.</p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">741</span><span class="k">class</span> <span class="nc">Adam</span><span class="p">:</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-167'>
        <div class='docs doc-strings'>
            <div class='section-link'>
                <a href='#section-167'>#</a>
            </div>
            <ul><li><code  class="highlight"><span></span><span class="n">params</span></code>
 is the tree-map of parameters </li>
<li><code  class="highlight"><span></span><span class="n">lr</span></code>
 is the learning rate <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.43056em;vertical-align:0em;"></span><span class="mord coloredeq eqbq" style=""><span class="mord mathnormal" style="margin-right:0.0037em">α</span></span></span></span></span></span> </li>
<li><code  class="highlight"><span></span><span class="n">betas</span></code>
 is a tuple of (<span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.8888799999999999em;vertical-align:-0.19444em;"></span><span class="mord coloredeq eqbl" style=""><span class="mord" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqbs" style="margin-right:0.05278em">β</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.30110799999999993em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style="">1</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span></span>, <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.8888799999999999em;vertical-align:-0.19444em;"></span><span class="mord coloredeq eqbm" style=""><span class="mord" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqbs" style="margin-right:0.05278em">β</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.30110799999999993em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style="">2</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span></span>) </li>
<li><code  class="highlight"><span></span><span class="n">eps</span></code>
 is <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.43056em;vertical-align:0em;"></span><span class="mord coloredeq eqbj" style=""><span class="mord mathnormal" style="">ϵ</span></span></span></span></span></span></li></ul>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">767</span>    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">params</span><span class="p">:</span> <span class="n">Dict</span><span class="p">,</span>
<span class="lineno">768</span>                 <span class="n">lr</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.001</span><span class="p">,</span> <span class="n">betas</span><span class="p">:</span> <span class="n">Tuple</span><span class="p">[</span><span class="nb">float</span><span class="p">,</span> <span class="nb">float</span><span class="p">]</span> <span class="o">=</span> <span class="p">(</span><span class="mf">0.9</span><span class="p">,</span> <span class="mf">0.999</span><span class="p">),</span>
<span class="lineno">769</span>                 <span class="n">eps</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1e-16</span><span class="p">,</span> <span class="p">):</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-168'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-168'>#</a>
            </div>
            
        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">777</span>        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
<span class="lineno">778</span>        <span class="bp">self</span><span class="o">.</span><span class="n">lr</span> <span class="o">=</span> <span class="n">lr</span>
<span class="lineno">779</span>        <span class="bp">self</span><span class="o">.</span><span class="n">betas</span> <span class="o">=</span> <span class="n">betas</span>
<span class="lineno">780</span>        <span class="bp">self</span><span class="o">.</span><span class="n">eps</span> <span class="o">=</span> <span class="n">eps</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-169'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-169'>#</a>
            </div>
            <p>States for each parameter </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">783</span>        <span class="bp">self</span><span class="o">.</span><span class="n">states</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">tree</span><span class="o">.</span><span class="n">map</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_init_state</span><span class="p">,</span> <span class="n">params</span><span class="p">)</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-170'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-170'>#</a>
            </div>
            <p>Optimized step function </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">785</span>        <span class="bp">self</span><span class="o">.</span><span class="n">_step_jit</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">jit</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_step</span><span class="p">)</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-171'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-171'>#</a>
            </div>
            <p>Number of steps taken <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.61508em;vertical-align:0em;"></span><span class="mord coloredeq eqce" style=""><span class="mord mathnormal" style="">t</span></span></span></span></span></span> </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">787</span>        <span class="bp">self</span><span class="o">.</span><span class="n">_n_steps</span> <span class="o">=</span> <span class="mi">0</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-172'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-172'>#</a>
            </div>
            <p>Optimized update state function </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">789</span>        <span class="bp">self</span><span class="o">.</span><span class="n">_update_state_jit</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">jit</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_update_state</span><span class="p">)</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-173'>
        <div class='docs doc-strings'>
            <div class='section-link'>
                <a href='#section-173'>#</a>
            </div>
            <p> Initialize the state for a given parameter</p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">791</span>    <span class="k">def</span> <span class="nf">_init_state</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">param</span><span class="p">:</span> <span class="n">jnp</span><span class="o">.</span><span class="n">ndarray</span><span class="p">):</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-174'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-174'>#</a>
            </div>
            
        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">795</span>        <span class="k">return</span> <span class="n">AdamState</span><span class="p">(</span><span class="n">jnp</span><span class="o">.</span><span class="n">zeros_like</span><span class="p">(</span><span class="n">param</span><span class="p">),</span> <span class="n">jnp</span><span class="o">.</span><span class="n">zeros_like</span><span class="p">(</span><span class="n">param</span><span class="p">))</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-175'>
        <div class='docs doc-strings'>
            <div class='section-link'>
                <a href='#section-175'>#</a>
            </div>
            <h2>Step function</h2>
<ul><li><code  class="highlight"><span></span><span class="n">params</span></code>
 is a tree-map of parameters </li>
<li><code  class="highlight"><span></span><span class="n">grads</span></code>
 is a tree-map of gradients</li></ul>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">797</span>    <span class="k">def</span> <span class="nf">step</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">params</span><span class="p">:</span> <span class="n">Dict</span><span class="p">,</span> <span class="n">grads</span><span class="p">:</span> <span class="n">Dict</span><span class="p">):</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-176'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-176'>#</a>
            </div>
            <p>Increment step <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.61508em;vertical-align:0em;"></span><span class="mord coloredeq eqce" style=""><span class="mord mathnormal" style="">t</span></span></span></span></span></span> </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">805</span>        <span class="bp">self</span><span class="o">.</span><span class="n">_n_steps</span> <span class="o">+=</span> <span class="mi">1</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-177'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-177'>#</a>
            </div>
            <p>Update states for each parameter </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">807</span>        <span class="bp">self</span><span class="o">.</span><span class="n">states</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">tree</span><span class="o">.</span><span class="n">map</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_update_state_jit</span><span class="p">,</span> <span class="n">grads</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">states</span><span class="p">)</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-178'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-178'>#</a>
            </div>
            <p>Return updated parameters <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.84444em;vertical-align:-0.15em;"></span><span class="mord coloredeq eqbk" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.02778em">θ</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.2805559999999999em;"><span style="top:-2.5500000000000003em;margin-left:-0.02778em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqce" style="">t</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span></span> </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">809</span>        <span class="k">return</span> <span class="n">jax</span><span class="o">.</span><span class="n">tree</span><span class="o">.</span><span class="n">map</span><span class="p">(</span><span class="n">partial</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_step_jit</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">_n_steps</span><span class="p">),</span> <span class="n">params</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">states</span><span class="p">)</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-179'>
        <div class='docs doc-strings'>
            <div class='section-link'>
                <a href='#section-179'>#</a>
            </div>
            <h3>Update parameters</h3>
<p>This performs an Adam update on the given parameter</p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">811</span>    <span class="k">def</span> <span class="nf">_step</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">n_steps</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">param</span><span class="p">:</span> <span class="n">jnp</span><span class="o">.</span><span class="n">ndarray</span><span class="p">,</span> <span class="n">state</span><span class="p">:</span> <span class="n">AdamState</span><span class="p">):</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-180'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-180'>#</a>
            </div>
            <p>Bias corrections for <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.84444em;vertical-align:-0.15em;"></span><span class="mord coloredeq eqbg" style=""><span class="mord" style=""><span class="mord accent" style=""><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.69444em;"><span style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="mord mathnormal" style="">m</span></span><span style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="accent-body" style="left:-0.25em;"><span class="mord" style="">^</span></span></span></span></span></span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.2805559999999999em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqce" style="">t</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span></span>: <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.72777em;vertical-align:-0.08333em;"></span><span class="mord">1</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin">−</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span></span><span class="base"><span class="strut" style="height:1.072336em;vertical-align:-0.19444em;"></span><span class="mord"><span class="mord coloredeq eqbl" style=""><span class="mord" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqbs" 
style="margin-right:0.05278em">β</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.30110799999999993em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style="">1</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.8778959999999999em;"><span style="top:-3.1473400000000002em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight coloredeq eqce" style=""><span class="mord mathnormal mtight" style="">t</span></span></span></span></span></span></span></span></span></span></span></span></span> and for <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.84444em;vertical-align:-0.15em;"></span><span class="mord coloredeq eqbh" style=""><span class="mord" style=""><span class="mord accent" style=""><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.69444em;"><span style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="mord mathnormal" style="margin-right:0.03588em">v</span></span><span style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="accent-body" style="left:-0.22222em;"><span class="mord" style="">^</span></span></span></span></span></span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.2805559999999999em;"><span style="top:-2.5500000000000003em;margin-left:-0.03588em;margin-right:0.05em;"><span class="pstrut" 
style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqce" style="">t</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span></span>: <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.72777em;vertical-align:-0.08333em;"></span><span class="mord">1</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin">−</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span></span><span class="base"><span class="strut" style="height:1.072336em;vertical-align:-0.19444em;"></span><span class="mord"><span class="mord coloredeq eqbm" style=""><span class="mord" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqbs" style="margin-right:0.05278em">β</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.30110799999999993em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style="">2</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.8778959999999999em;"><span style="top:-3.1473400000000002em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight coloredeq eqce" style=""><span class="mord mathnormal mtight" 
style="">t</span></span></span></span></span></span></span></span></span></span></span></span></span> </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">819</span>        <span class="n">bias_correction</span> <span class="o">=</span> <span class="p">[</span><span class="mi">1</span> <span class="o">-</span> <span class="n">beta</span> <span class="o">**</span> <span class="n">n_steps</span> <span class="k">for</span> <span class="n">beta</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">betas</span><span class="p">]</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-181'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-181'>#</a>
            </div>
            <p>Uncorrected first and second moments <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.58056em;vertical-align:-0.15em;"></span><span class="mord coloredeq eqbx" style=""><span class="mord" style=""><span class="mord mathnormal" style="">m</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.2805559999999999em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqce" style="">t</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span></span> and <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.58056em;vertical-align:-0.15em;"></span><span class="mord coloredeq eqby" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.03588em">v</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.2805559999999999em;"><span style="top:-2.5500000000000003em;margin-left:-0.03588em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqce" style="">t</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span></span> </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">821</span>        <span class="n">m</span><span class="p">,</span> <span class="n">v</span> <span class="o">=</span> <span class="n">state</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-182'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-182'>#</a>
            </div>
            <p><span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1.822356em;vertical-align:-0.49275599999999986em;"></span><span class="mord coloredeq eql" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqbq" style="margin-right:0.0037em">α</span></span><span class="mord" style=""><span class="mopen nulldelimiter"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:1.3296000000000001em;"><span style="top:-2.643352em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight" style="">1</span><span class="mbin mtight" style="">−</span><span class="mord mtight" style=""><span class="mord mtight" style=""><span class="mord mtight coloredeq eqbl" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqbs" style="margin-right:0.05278em">β</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31731428571428577em;"><span style="top:-2.357em;margin-right:0.07142857142857144em;"><span class="pstrut" style="height:2.5em;"></span><span class="sizing reset-size3 size1 mtight" style=""><span class="mord mtight" style="">1</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.143em;"><span></span></span></span></span></span></span></span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.7809257142857142em;"><span style="top:-2.841582857142857em;margin-right:0.07142857142857144em;"><span class="pstrut" style="height:2.5em;"></span><span class="sizing reset-size3 size1 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqce" 
style="">t</span></span></span></span></span></span></span></span></span></span></span></span><span style="top:-3.23em;"><span class="pstrut" style="height:3em;"></span><span class="frac-line" style="border-bottom-width:0.04em"></span></span><span style="top:-3.5734925em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord sqrt mtight" style=""><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:1.0801535714285715em;"><span class="svg-align" style="top:-3.428571428571429em;"><span class="pstrut" style="height:3.428571428571429em;"></span><span class="mord mtight" style="padding-left:1.19em"><span class="mord mtight" style="">1</span><span class="mbin mtight" style="">−</span><span class="mord mtight" style=""><span class="mord mtight" style=""><span class="mord mtight coloredeq eqbm" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqbs" style="margin-right:0.05278em">β</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31731428571428577em;"><span style="top:-2.357em;margin-right:0.07142857142857144em;"><span class="pstrut" style="height:2.5em;"></span><span class="sizing reset-size3 size1 mtight" style=""><span class="mord mtight" style="">2</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.143em;"><span></span></span></span></span></span></span></span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.7809257142857142em;"><span style="top:-2.841582857142857em;margin-right:0.07142857142857144em;"><span class="pstrut" style="height:2.5em;"></span><span class="sizing reset-size3 size1 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqce" 
style="">t</span></span></span></span></span></span></span></span></span></span></span><span style="top:-3.0521535714285717em;"><span class="pstrut" style="height:3.428571428571429em;"></span><span class="hide-tail mtight" style="min-width:0.853em;height:1.5428571428571431em"><svg height="1.5428571428571431em" preserveaspectratio="xMinYMin slice" viewbox="0 0 400000 1080" width="400em" xmlns="http://www.w3.org/2000/svg"><path d="M95,702
c-2.7,0,-7.17,-2.7,-13.5,-8c-5.8,-5.3,-9.5,-10,-9.5,-14
c0,-2,0.3,-3.3,1,-4c1.3,-2.7,23.83,-20.7,67.5,-54
c44.2,-33.3,65.8,-50.3,66.5,-51c1.3,-1.3,3,-2,5,-2c4.7,0,8.7,3.3,12,10
s173,378,173,378c0.7,0,35.3,-71,104,-213c68.7,-142,137.5,-285,206.5,-429
c69,-144,104.5,-217.7,106.5,-221
l0 -0
c5.3,-9.3,12,-14,20,-14
H400000v40H845.2724
s-225.272,467,-225.272,467s-235,486,-235,486c-2.7,4.7,-9,7,-19,7
c-6,0,-10,-1,-12,-3s-194,-422,-194,-422s-65,47,-65,47z
M834 80h400000v40h-400000z"></path></svg></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.37641785714285714em;"><span></span></span></span></span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.49275599999999986em;"><span></span></span></span></span></span><span class="mclose nulldelimiter"></span></span></span></span></span></span></span> </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">824</span>        <span class="n">step_size</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">lr</span> <span class="o">*</span> <span class="p">(</span><span class="n">bias_correction</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">**</span> <span class="mf">0.5</span><span class="p">)</span> <span class="o">/</span> <span class="n">bias_correction</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-183'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-183'>#</a>
            </div>
            <p><span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1.04em;vertical-align:-0.31472em;"></span><span class="mord coloredeq eqs" style=""><span class="mord sqrt" style=""><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.72528em;"><span class="svg-align" style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="mord" style="padding-left:0.833em"><span class="mord" style=""><span class="mord coloredeq eqby" style=""><span class="mord mathnormal" style="margin-right:0.03588em">v</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.2805559999999999em;"><span style="top:-2.5500000000000003em;margin-left:-0.03588em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqce" style="">t</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span><span style="top:-2.68528em;"><span class="pstrut" style="height:3em;"></span><span class="hide-tail" style="min-width:0.853em;height:1.08em"><svg height="1.08em" preserveaspectratio="xMinYMin slice" viewbox="0 0 400000 1080" width="400em" xmlns="http://www.w3.org/2000/svg"><path d="M95,702
c-2.7,0,-7.17,-2.7,-13.5,-8c-5.8,-5.3,-9.5,-10,-9.5,-14
c0,-2,0.3,-3.3,1,-4c1.3,-2.7,23.83,-20.7,67.5,-54
c44.2,-33.3,65.8,-50.3,66.5,-51c1.3,-1.3,3,-2,5,-2c4.7,0,8.7,3.3,12,10
s173,378,173,378c0.7,0,35.3,-71,104,-213c68.7,-142,137.5,-285,206.5,-429
c69,-144,104.5,-217.7,106.5,-221
l0 -0
c5.3,-9.3,12,-14,20,-14
H400000v40H845.2724
s-225.272,467,-225.272,467s-235,486,-235,486c-2.7,4.7,-9,7,-19,7
c-6,0,-10,-1,-12,-3s-194,-422,-194,-422s-65,47,-65,47z
M834 80h400000v40h-400000z"></path></svg></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.31472em;"><span></span></span></span></span></span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin" style="">+</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mord" style=""><span class="mord accent coloredeq eqbb" style=""><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.69444em;"><span style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="mord" style=""><span class="mord mathnormal coloredeq eqbj" style="">ϵ</span></span></span><span style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="accent-body" style="left:-0.19444em;"><span class="mord" style="">^</span></span></span></span></span></span></span></span></span></span></span></span></span> </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">826</span>        <span class="n">den</span> <span class="o">=</span> <span class="p">(</span><span class="n">v</span> <span class="o">**</span> <span class="mf">0.5</span><span class="p">)</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">eps</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-184'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-184'>#</a>
            </div>
            <p><span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.84444em;vertical-align:-0.15em;"></span><span class="mord coloredeq eqbk" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.02778em">θ</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.2805559999999999em;"><span style="top:-2.5500000000000003em;margin-left:-0.02778em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqce" style="">t</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">←</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span></span><span class="base"><span class="strut" style="height:0.902771em;vertical-align:-0.208331em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.02778em;">θ</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.301108em;"><span style="top:-2.5500000000000003em;margin-left:-0.02778em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight coloredeq eqce" style=""><span class="mord mathnormal mtight" style="">t</span></span><span class="mbin mtight">−</span><span class="mord mtight">1</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.208331em;"><span></span></span></span></span></span></span><span 
class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin">−</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span></span><span class="base"><span class="strut" style="height:1.822356em;vertical-align:-0.49275599999999986em;"></span><span class="mord coloredeq eql" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqbq" style="margin-right:0.0037em">α</span></span><span class="mord" style=""><span class="mopen nulldelimiter"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:1.3296000000000001em;"><span style="top:-2.643352em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight" style="">1</span><span class="mbin mtight" style="">−</span><span class="mord mtight" style=""><span class="mord mtight" style=""><span class="mord mtight coloredeq eqbl" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqbs" style="margin-right:0.05278em">β</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31731428571428577em;"><span style="top:-2.357em;margin-right:0.07142857142857144em;"><span class="pstrut" style="height:2.5em;"></span><span class="sizing reset-size3 size1 mtight" style=""><span class="mord mtight" style="">1</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.143em;"><span></span></span></span></span></span></span></span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.7809257142857142em;"><span style="top:-2.841582857142857em;margin-right:0.07142857142857144em;"><span class="pstrut" style="height:2.5em;"></span><span class="sizing reset-size3 size1 mtight" style=""><span class="mord mtight" 
style=""><span class="mord mathnormal mtight coloredeq eqce" style="">t</span></span></span></span></span></span></span></span></span></span></span></span><span style="top:-3.23em;"><span class="pstrut" style="height:3em;"></span><span class="frac-line" style="border-bottom-width:0.04em"></span></span><span style="top:-3.5734925em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord sqrt mtight" style=""><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:1.0801535714285715em;"><span class="svg-align" style="top:-3.428571428571429em;"><span class="pstrut" style="height:3.428571428571429em;"></span><span class="mord mtight" style="padding-left:1.19em"><span class="mord mtight" style="">1</span><span class="mbin mtight" style="">−</span><span class="mord mtight" style=""><span class="mord mtight" style=""><span class="mord mtight coloredeq eqbm" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqbs" style="margin-right:0.05278em">β</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31731428571428577em;"><span style="top:-2.357em;margin-right:0.07142857142857144em;"><span class="pstrut" style="height:2.5em;"></span><span class="sizing reset-size3 size1 mtight" style=""><span class="mord mtight" style="">2</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.143em;"><span></span></span></span></span></span></span></span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.7809257142857142em;"><span style="top:-2.841582857142857em;margin-right:0.07142857142857144em;"><span class="pstrut" style="height:2.5em;"></span><span class="sizing reset-size3 size1 mtight" style=""><span class="mord mtight" 
style=""><span class="mord mathnormal mtight coloredeq eqce" style="">t</span></span></span></span></span></span></span></span></span></span></span><span style="top:-3.0521535714285717em;"><span class="pstrut" style="height:3.428571428571429em;"></span><span class="hide-tail mtight" style="min-width:0.853em;height:1.5428571428571431em"><svg height="1.5428571428571431em" preserveaspectratio="xMinYMin slice" viewbox="0 0 400000 1080" width="400em" xmlns="http://www.w3.org/2000/svg"><path d="M95,702
c-2.7,0,-7.17,-2.7,-13.5,-8c-5.8,-5.3,-9.5,-10,-9.5,-14
c0,-2,0.3,-3.3,1,-4c1.3,-2.7,23.83,-20.7,67.5,-54
c44.2,-33.3,65.8,-50.3,66.5,-51c1.3,-1.3,3,-2,5,-2c4.7,0,8.7,3.3,12,10
s173,378,173,378c0.7,0,35.3,-71,104,-213c68.7,-142,137.5,-285,206.5,-429
c69,-144,104.5,-217.7,106.5,-221
l0 -0
c5.3,-9.3,12,-14,20,-14
H400000v40H845.2724
s-225.272,467,-225.272,467s-235,486,-235,486c-2.7,4.7,-9,7,-19,7
c-6,0,-10,-1,-12,-3s-194,-422,-194,-422s-65,47,-65,47z
M834 80h400000v40h-400000z"></path></svg></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.37641785714285714em;"><span></span></span></span></span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.49275599999999986em;"><span></span></span></span></span></span><span class="mclose nulldelimiter"></span></span></span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin">⋅</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span></span><span class="base"><span class="strut" style="height:1.2704084999999998em;vertical-align:-0.5589165em;"></span><span class="mord"><span class="mopen nulldelimiter"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.7114919999999999em;"><span style="top:-2.655em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight coloredeq eqs" style=""><span class="mord sqrt mtight" style=""><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.734405em;"><span class="svg-align" style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="mord mtight" style="padding-left:0.833em"><span class="mord mtight" style=""><span class="mord mtight coloredeq eqby" style=""><span class="mord mathnormal mtight" style="margin-right:0.03588em">v</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.29634285714285713em;"><span style="top:-2.357em;margin-left:-0.03588em;margin-right:0.07142857142857144em;"><span class="pstrut" style="height:2.5em;"></span><span class="sizing reset-size3 size1 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqce" 
style="">t</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.143em;"><span></span></span></span></span></span></span></span></span></span><span style="top:-2.6944049999999997em;"><span class="pstrut" style="height:3em;"></span><span class="hide-tail mtight" style="min-width:0.853em;height:1.08em"><svg height="1.08em" preserveaspectratio="xMinYMin slice" viewbox="0 0 400000 1080" width="400em" xmlns="http://www.w3.org/2000/svg"><path d="M95,702
c-2.7,0,-7.17,-2.7,-13.5,-8c-5.8,-5.3,-9.5,-10,-9.5,-14
c0,-2,0.3,-3.3,1,-4c1.3,-2.7,23.83,-20.7,67.5,-54
c44.2,-33.3,65.8,-50.3,66.5,-51c1.3,-1.3,3,-2,5,-2c4.7,0,8.7,3.3,12,10
s173,378,173,378c0.7,0,35.3,-71,104,-213c68.7,-142,137.5,-285,206.5,-429
c69,-144,104.5,-217.7,106.5,-221
l0 -0
c5.3,-9.3,12,-14,20,-14
H400000v40H845.2724
s-225.272,467,-225.272,467s-235,486,-235,486c-2.7,4.7,-9,7,-19,7
c-6,0,-10,-1,-12,-3s-194,-422,-194,-422s-65,47,-65,47z
M834 80h400000v40h-400000z"></path></svg></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.30559500000000006em;"><span></span></span></span></span></span><span class="mbin mtight" style="">+</span><span class="mord mtight" style=""><span class="mord accent mtight coloredeq eqbb" style=""><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.69444em;"><span style="top:-2.7em;"><span class="pstrut" style="height:2.7em;"></span><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqbj" style="">ϵ</span></span></span><span style="top:-2.7em;"><span class="pstrut" style="height:2.7em;"></span><span class="accent-body" style="left:-0.19444em;"><span class="mord mtight" style="">^</span></span></span></span></span></span></span></span></span></span></span></span><span style="top:-3.23em;"><span class="pstrut" style="height:3em;"></span><span class="frac-line" style="border-bottom-width:0.04em;"></span></span><span style="top:-3.4101em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight coloredeq eqbx" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight" style="">m</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.29634285714285713em;"><span style="top:-2.357em;margin-left:0em;margin-right:0.07142857142857144em;"><span class="pstrut" style="height:2.5em;"></span><span class="sizing reset-size3 size1 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqce" style="">t</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.143em;"><span></span></span></span></span></span></span></span></span></span></span></span><span 
class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.5589165em;"><span></span></span></span></span></span><span class="mclose nulldelimiter"></span></span></span></span></span></span> </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">830</span>        <span class="k">return</span> <span class="n">param</span> <span class="o">-</span> <span class="n">step_size</span> <span class="o">*</span> <span class="n">m</span> <span class="o">/</span> <span class="n">den</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-185'>
        <div class='docs doc-strings'>
            <div class='section-link'>
                <a href='#section-185'>#</a>
            </div>
            <h3>Update state</h3>
<p>This updates the uncorrected first and second moments <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.58056em;vertical-align:-0.15em;"></span><span class="mord coloredeq eqbx" style=""><span class="mord" style=""><span class="mord mathnormal" style="">m</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.2805559999999999em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqce" style="">t</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span></span> and <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.58056em;vertical-align:-0.15em;"></span><span class="mord coloredeq eqby" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.03588em">v</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.2805559999999999em;"><span style="top:-2.5500000000000003em;margin-left:-0.03588em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqce" style="">t</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span></span></p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">832</span>    <span class="k">def</span> <span class="nf">_update_state</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">grad</span><span class="p">,</span> <span class="n">state</span><span class="p">:</span> <span class="n">AdamState</span><span class="p">):</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-186'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-186'>#</a>
            </div>
            <p>Uncorrected first and second moments <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.638891em;vertical-align:-0.208331em;"></span><span class="mord coloredeq eqbn" style=""><span class="mord" style=""><span class="mord mathnormal" style="">m</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.301108em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqce" style="">t</span></span><span class="mbin mtight" style="">−</span><span class="mord mtight" style="">1</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.208331em;"><span></span></span></span></span></span></span></span></span></span></span></span> and <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.638891em;vertical-align:-0.208331em;"></span><span class="mord coloredeq eqbp" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.03588em">v</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.301108em;"><span style="top:-2.5500000000000003em;margin-left:-0.03588em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqce" style="">t</span></span><span class="mbin mtight" style="">−</span><span class="mord mtight" style="">1</span></span></span></span></span><span 
class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.208331em;"><span></span></span></span></span></span></span></span></span></span></span></span> </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">839</span>        <span class="n">m</span><span class="p">,</span> <span class="n">v</span> <span class="o">=</span> <span class="n">state</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-187'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-187'>#</a>
            </div>
            <p>Clip each gradient element to the range [-1, 1] (element-wise value clipping, not norm clipping) </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">841</span>        <span class="n">grad</span> <span class="o">=</span> <span class="n">jnp</span><span class="o">.</span><span class="n">clip</span><span class="p">(</span><span class="n">grad</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-188'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-188'>#</a>
            </div>
            <p><span ><span class="katex-display"><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.58056em;vertical-align:-0.15em;"></span><span class="mord coloredeq eqbx" style=""><span class="mord" style=""><span class="mord mathnormal" style="">m</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.2805559999999999em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqce" style="">t</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">←</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span></span><span class="base"><span class="strut" style="height:0.902771em;vertical-align:-0.208331em;"></span><span class="mord coloredeq eqbl" style=""><span class="mord" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqbs" style="margin-right:0.05278em">β</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.30110799999999993em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style="">1</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span><span class="mord coloredeq eqbn" style=""><span class="mord" style=""><span 
class="mord mathnormal" style="">m</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.301108em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqce" style="">t</span></span><span class="mbin mtight" style="">−</span><span class="mord mtight" style="">1</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.208331em;"><span></span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin">+</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span></span><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mopen">(</span><span class="mord">1</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin">−</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span></span><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord coloredeq eqbl" style=""><span class="mord" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqbs" style="margin-right:0.05278em">β</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.30110799999999993em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style="">1</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" 
style="height:0.15em;"><span></span></span></span></span></span></span></span><span class="mclose">)</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin">⋅</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span></span><span class="base"><span class="strut" style="height:0.625em;vertical-align:-0.19444em;"></span><span class="mord coloredeq eqbw" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.03588em">g</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.2805559999999999em;"><span style="top:-2.5500000000000003em;margin-left:-0.03588em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqce" style="">t</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span></span></span> </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">843</span>        <span class="n">m</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">betas</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">*</span> <span class="n">m</span> <span class="o">+</span> <span class="n">grad</span> <span class="o">*</span> <span class="p">(</span><span class="mi">1</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">betas</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-189'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-189'>#</a>
            </div>
            <p><span ><span class="katex-display"><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.58056em;vertical-align:-0.15em;"></span><span class="mord coloredeq eqby" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.03588em">v</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.2805559999999999em;"><span style="top:-2.5500000000000003em;margin-left:-0.03588em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqce" style="">t</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">←</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span></span><span class="base"><span class="strut" style="height:0.902771em;vertical-align:-0.208331em;"></span><span class="mord coloredeq eqbm" style=""><span class="mord" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqbs" style="margin-right:0.05278em">β</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.30110799999999993em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style="">2</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span><span class="mord coloredeq eqbp" style=""><span 
class="mord" style=""><span class="mord mathnormal" style="margin-right:0.03588em">v</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.301108em;"><span style="top:-2.5500000000000003em;margin-left:-0.03588em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqce" style="">t</span></span><span class="mbin mtight" style="">−</span><span class="mord mtight" style="">1</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.208331em;"><span></span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin">+</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span></span><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mopen">(</span><span class="mord">1</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin">−</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span></span><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord coloredeq eqbm" style=""><span class="mord" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqbs" style="margin-right:0.05278em">β</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.30110799999999993em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style="">2</span></span></span></span><span 
class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span><span class="mclose">)</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin">⋅</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span></span><span class="base"><span class="strut" style="height:1.0585479999999998em;vertical-align:-0.19444em;"></span><span class="mord"><span class="mord coloredeq eqbw" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.03588em">g</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.2805559999999999em;"><span style="top:-2.5500000000000003em;margin-left:-0.03588em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqce" style="">t</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.8641079999999999em;"><span style="top:-3.113em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">2</span></span></span></span></span></span></span></span></span></span></span></span></span> </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">845</span>        <span class="n">v</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">betas</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">*</span> <span class="n">v</span> <span class="o">+</span> <span class="p">(</span><span class="n">grad</span> <span class="o">**</span> <span class="mi">2</span><span class="p">)</span> <span class="o">*</span> <span class="p">(</span><span class="mi">1</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">betas</span><span class="p">[</span><span class="mi">1</span><span class="p">])</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-190'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-190'>#</a>
            </div>
            <p>Return the new state </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">848</span>        <span class="k">return</span> <span class="n">AdamState</span><span class="p">(</span><span class="n">m</span><span class="p">,</span> <span class="n">v</span><span class="p">)</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-191'>
        <div class='docs doc-strings'>
            <div class='section-link'>
                <a href='#section-191'>#</a>
            </div>
            <p> <a id="Dataset"></a></p>
<h2>Tiny Shakespeare dataset</h2>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">851</span><span class="k">class</span> <span class="nc">TinyShakespeare</span><span class="p">:</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-192'>
        <div class='docs doc-strings'>
            <div class='section-link'>
                <a href='#section-192'>#</a>
            </div>
            <ul><li><code  class="highlight"><span></span><span class="n">rnd_key</span></code>
 is the PRNG key </li>
<li><code  class="highlight"><span></span><span class="n">seq_len</span></code>
 is the sequence length of a sample </li>
<li><code  class="highlight"><span></span><span class="n">batch_size</span></code>
 is the batch size</li></ul>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">858</span>    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">rnd_key</span><span class="p">:</span> <span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">PRNGKey</span><span class="p">,</span> <span class="n">seq_len</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-193'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-193'>#</a>
            </div>
            
        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">865</span>        <span class="bp">self</span><span class="o">.</span><span class="n">batch_size</span> <span class="o">=</span> <span class="n">batch_size</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-194'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-194'>#</a>
            </div>
            <p>PRNG key for shuffling the samples </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">867</span>        <span class="n">_</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">rnd_key</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="n">rnd_key</span><span class="p">)</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-195'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-195'>#</a>
            </div>
            <p>Local path of the text file </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">870</span>        <span class="n">path</span> <span class="o">=</span> <span class="n">lab</span><span class="o">.</span><span class="n">get_data_path</span><span class="p">()</span> <span class="o">/</span> <span class="s1">&#39;tiny_shakespeare.txt&#39;</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-196'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-196'>#</a>
            </div>
            <p>Download if it doesn&#x27;t exist </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">872</span>        <span class="n">url</span> <span class="o">=</span> <span class="s1">&#39;https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt&#39;</span>
<span class="lineno">873</span>        <span class="k">if</span> <span class="ow">not</span> <span class="n">path</span><span class="o">.</span><span class="n">exists</span><span class="p">():</span>
<span class="lineno">874</span>            <span class="n">download_file</span><span class="p">(</span><span class="n">url</span><span class="p">,</span> <span class="n">path</span><span class="p">)</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-197'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-197'>#</a>
            </div>
            <p>Read the file </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">877</span>        <span class="k">with</span> <span class="nb">open</span><span class="p">(</span><span class="nb">str</span><span class="p">(</span><span class="n">path</span><span class="p">),</span> <span class="s1">&#39;r&#39;</span><span class="p">)</span> <span class="k">as</span> <span class="n">f</span><span class="p">:</span>
<span class="lineno">878</span>            <span class="bp">self</span><span class="o">.</span><span class="n">text</span> <span class="o">=</span> <span class="n">f</span><span class="o">.</span><span class="n">read</span><span class="p">()</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-198'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-198'>#</a>
            </div>
            <p>Get the sorted list of unique characters (tokens) </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">881</span>        <span class="n">tokens</span> <span class="o">=</span> <span class="nb">sorted</span><span class="p">(</span><span class="nb">list</span><span class="p">(</span><span class="nb">set</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">text</span><span class="p">)))</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-199'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-199'>#</a>
            </div>
            <p>Number of tokens </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">884</span>        <span class="bp">self</span><span class="o">.</span><span class="n">n_tokens</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">tokens</span><span class="p">)</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-200'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-200'>#</a>
            </div>
            <p>Map tokens to ids </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">886</span>        <span class="bp">self</span><span class="o">.</span><span class="n">stoi</span> <span class="o">=</span> <span class="p">{</span><span class="n">t</span><span class="p">:</span> <span class="n">i</span> <span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">t</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">tokens</span><span class="p">)}</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-201'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-201'>#</a>
            </div>
            <p>Id to token/character </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">888</span>        <span class="bp">self</span><span class="o">.</span><span class="n">itos</span> <span class="o">=</span> <span class="n">tokens</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-202'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-202'>#</a>
            </div>
            <p>Convert the text into an array of token ids </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">891</span>        <span class="n">data</span> <span class="o">=</span> <span class="n">jnp</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="bp">self</span><span class="o">.</span><span class="n">stoi</span><span class="p">[</span><span class="n">s</span><span class="p">]</span> <span class="k">for</span> <span class="n">s</span> <span class="ow">in</span> <span class="nb">list</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">text</span><span class="p">)])</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-203'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-203'>#</a>
            </div>
            <p>Number of batches </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">893</span>        <span class="bp">self</span><span class="o">.</span><span class="n">n_batches</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">data</span><span class="p">)</span> <span class="o">//</span> <span class="p">(</span><span class="n">seq_len</span> <span class="o">*</span> <span class="n">batch_size</span><span class="p">)</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-204'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-204'>#</a>
            </div>
            <p>Truncate the data to a whole number of batches </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">895</span>        <span class="n">data</span> <span class="o">=</span> <span class="n">data</span><span class="p">[:</span><span class="bp">self</span><span class="o">.</span><span class="n">n_batches</span> <span class="o">*</span> <span class="n">seq_len</span> <span class="o">*</span> <span class="n">batch_size</span><span class="p">]</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-205'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-205'>#</a>
            </div>
            <p>Reshape into samples of length <code>seq_len</code> (it would be better to use random offsets, but let's ignore that here) </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">897</span>        <span class="bp">self</span><span class="o">.</span><span class="n">data</span> <span class="o">=</span> <span class="n">data</span><span class="o">.</span><span class="n">reshape</span><span class="p">((</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="n">seq_len</span><span class="p">))</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-206'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-206'>#</a>
            </div>
            <p>List of sample indexes </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">899</span>        <span class="bp">self</span><span class="o">.</span><span class="n">idx</span> <span class="o">=</span> <span class="n">jnp</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">data</span><span class="p">))</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-207'>
        <div class='docs doc-strings'>
            <div class='section-link'>
                <a href='#section-207'>#</a>
            </div>
            <p> Set up for iteration</p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">901</span>    <span class="k">def</span> <span class="fm">__iter__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-208'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-208'>#</a>
            </div>
            <p>Iteration step </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">906</span>        <span class="bp">self</span><span class="o">.</span><span class="n">_iter_idx</span> <span class="o">=</span> <span class="mi">0</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-209'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-209'>#</a>
            </div>
            <p>Split PRNG key </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">908</span>        <span class="bp">self</span><span class="o">.</span><span class="n">rnd_key</span><span class="p">,</span> <span class="n">rnd_key</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">rnd_key</span><span class="p">)</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-210'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-210'>#</a>
            </div>
            <p>Shuffle sample indexes </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">910</span>        <span class="bp">self</span><span class="o">.</span><span class="n">idx</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">permutation</span><span class="p">(</span><span class="n">rnd_key</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">idx</span><span class="p">)</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-211'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-211'>#</a>
            </div>
            <p>Return the dataset object itself as the iterator </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">913</span>        <span class="k">return</span> <span class="bp">self</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-212'>
        <div class='docs doc-strings'>
            <div class='section-link'>
                <a href='#section-212'>#</a>
            </div>
            <p> Number of batches</p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">915</span>    <span class="k">def</span> <span class="fm">__len__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-213'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-213'>#</a>
            </div>
            
        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">919</span>        <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_batches</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-214'>
        <div class='docs doc-strings'>
            <div class='section-link'>
                <a href='#section-214'>#</a>
            </div>
            <p> Get next batch</p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">921</span>    <span class="k">def</span> <span class="fm">__next__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-215'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-215'>#</a>
            </div>
            <p>Stop iteration after iterating through all batches </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">927</span>        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">_iter_idx</span> <span class="o">&gt;=</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_batches</span><span class="p">:</span>
<span class="lineno">928</span>            <span class="k">raise</span> <span class="ne">StopIteration</span><span class="p">()</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-216'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-216'>#</a>
            </div>
            <p>Sample indexes for the batch </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">931</span>        <span class="n">idx</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">idx</span><span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">_iter_idx</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">batch_size</span><span class="p">:(</span><span class="bp">self</span><span class="o">.</span><span class="n">_iter_idx</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">batch_size</span><span class="p">]</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-217'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-217'>#</a>
            </div>
            <p>Increment iteration step </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">933</span>        <span class="bp">self</span><span class="o">.</span><span class="n">_iter_idx</span> <span class="o">+=</span> <span class="mi">1</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-218'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-218'>#</a>
            </div>
            <p>Return samples </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">936</span>        <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">data</span><span class="p">[</span><span class="n">idx</span><span class="p">]</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-219'>
        <div class='docs doc-strings'>
            <div class='section-link'>
                <a href='#section-219'>#</a>
            </div>
            <p> <a id="Experiment"></a></p>
<h2>Run the experiment</h2>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">939</span><span class="k">def</span> <span class="nf">main</span><span class="p">():</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-220'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-220'>#</a>
            </div>
            <p>Create experiment </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">947</span>    <span class="n">experiment</span><span class="o">.</span><span class="n">create</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="s1">&#39;jax&#39;</span><span class="p">)</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-221'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-221'>#</a>
            </div>
            <p>Create PRNG key </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">949</span>    <span class="n">rnd_key</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">PRNGKey</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-222'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-222'>#</a>
            </div>
            <p>Create dataset </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">951</span>    <span class="n">dataset</span> <span class="o">=</span> <span class="n">TinyShakespeare</span><span class="p">(</span><span class="n">rnd_key</span><span class="p">,</span> <span class="n">seq_len</span><span class="o">=</span><span class="mi">32</span><span class="p">,</span> <span class="n">batch_size</span><span class="o">=</span><span class="mi">128</span><span class="p">)</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-223'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-223'>#</a>
            </div>
            <p>Create the model </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">954</span>    <span class="n">model</span> <span class="o">=</span> <span class="n">AutoregressiveTransformer</span><span class="p">(</span><span class="n">rnd_key</span><span class="p">,</span> <span class="n">dataset</span><span class="o">.</span><span class="n">n_tokens</span><span class="p">,</span>
<span class="lineno">955</span>                                      <span class="n">d_model</span><span class="o">=</span><span class="mi">128</span><span class="p">,</span> <span class="n">n_layers</span><span class="o">=</span><span class="mi">3</span><span class="p">,</span> <span class="n">heads</span><span class="o">=</span><span class="mi">8</span><span class="p">,</span> <span class="n">d_ff</span><span class="o">=</span><span class="mi">512</span><span class="p">)</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-224'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-224'>#</a>
            </div>
            <p>Get model parameters </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">957</span>    <span class="n">params</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">get_params</span><span class="p">()</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-225'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-225'>#</a>
            </div>
            <p>JAX-compiled pure sampling function </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">960</span>    <span class="n">pure_sample_fn</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">jit</span><span class="p">(</span><span class="n">model</span><span class="o">.</span><span class="n">purify</span><span class="p">(</span><span class="n">model</span><span class="o">.</span><span class="n">sample</span><span class="p">))</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-226'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-226'>#</a>
            </div>
            <p>JAX-compiled pure function to get logits for a batch. First we transform <code  class="highlight"><span></span><span class="n">model</span><span class="o">.</span><span class="fm">__call__</span></code>
 into a pure function which accepts two arguments: the parameters and the input sequence. Next we vectorize the function to process a batch of samples. <code  class="highlight"><span></span><span class="n">in_axes</span></code>
 specifies which arguments to map over and along which axis. <code  class="highlight"><span></span><span class="p">(</span><span class="kc">None</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span></code>
 means the same parameters are shared across the whole batch, while the inputs are mapped over their first axis. <code  class="highlight"><span></span><span class="n">out_axes</span></code>
 specifies the axis along which the per-sample results are stacked. </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">968</span>    <span class="n">pure_forward_fn</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">jit</span><span class="p">(</span><span class="n">jax</span><span class="o">.</span><span class="n">vmap</span><span class="p">(</span><span class="n">model</span><span class="o">.</span><span class="n">purify</span><span class="p">(</span><span class="n">model</span><span class="o">.</span><span class="fm">__call__</span><span class="p">),</span>
<span class="lineno">969</span>                                       <span class="n">in_axes</span><span class="o">=</span><span class="p">(</span><span class="kc">None</span><span class="p">,</span> <span class="mi">0</span><span class="p">),</span> <span class="n">out_axes</span><span class="o">=</span><span class="mi">0</span><span class="p">))</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-227'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-227'>#</a>
            </div>
            <p>Similarly, we vectorize the loss computation </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">971</span>    <span class="n">pure_loss_fn</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">jit</span><span class="p">(</span><span class="n">jax</span><span class="o">.</span><span class="n">vmap</span><span class="p">(</span><span class="n">model</span><span class="o">.</span><span class="n">purify</span><span class="p">(</span><span class="n">model</span><span class="o">.</span><span class="n">get_loss</span><span class="p">),</span>
<span class="lineno">972</span>                                    <span class="n">in_axes</span><span class="o">=</span><span class="p">(</span><span class="kc">None</span><span class="p">,</span> <span class="mi">0</span><span class="p">),</span> <span class="n">out_axes</span><span class="o">=</span><span class="mi">0</span><span class="p">))</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-228'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-228'>#</a>
            </div>
            <p>A function to get the mean loss over the batch </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">975</span>    <span class="k">def</span> <span class="nf">get_loss</span><span class="p">(</span><span class="n">params</span><span class="p">,</span> <span class="n">seq</span><span class="p">):</span>
<span class="lineno">976</span>        <span class="k">return</span> <span class="n">pure_loss_fn</span><span class="p">(</span><span class="n">params</span><span class="p">,</span> <span class="n">seq</span><span class="p">)</span><span class="o">.</span><span class="n">mean</span><span class="p">()</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-229'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-229'>#</a>
            </div>
            <p>A function to compute gradients with respect to the first argument (the parameters) </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">979</span>    <span class="n">grad_loss_fn</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">jit</span><span class="p">(</span><span class="n">jax</span><span class="o">.</span><span class="n">grad</span><span class="p">(</span><span class="n">get_loss</span><span class="p">,</span> <span class="n">argnums</span><span class="o">=</span><span class="mi">0</span><span class="p">))</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-230'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-230'>#</a>
            </div>
            <p>Create optimizer </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">982</span>    <span class="n">optimizer</span> <span class="o">=</span> <span class="n">Adam</span><span class="p">(</span><span class="n">params</span><span class="p">)</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-231'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-231'>#</a>
            </div>
            <p>Start the experiment </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">985</span>    <span class="k">with</span> <span class="n">experiment</span><span class="o">.</span><span class="n">start</span><span class="p">():</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-232'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-232'>#</a>
            </div>
            <p>Iterate for 32 epochs </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">987</span>        <span class="k">for</span> <span class="n">epoch</span> <span class="ow">in</span> <span class="n">monit</span><span class="o">.</span><span class="n">loop</span><span class="p">(</span><span class="mi">32</span><span class="p">):</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-233'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-233'>#</a>
            </div>
            <p>Iterate through batches </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">989</span>            <span class="k">for</span> <span class="n">data</span> <span class="ow">in</span> <span class="n">monit</span><span class="o">.</span><span class="n">iterate</span><span class="p">(</span><span class="s1">&#39;Train&#39;</span><span class="p">,</span> <span class="n">dataset</span><span class="p">):</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-234'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-234'>#</a>
            </div>
            <p>Compute and log the loss </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">991</span>                <span class="n">loss</span> <span class="o">=</span> <span class="n">get_loss</span><span class="p">(</span><span class="n">params</span><span class="p">,</span> <span class="n">data</span><span class="p">)</span>
<span class="lineno">992</span>                <span class="n">tracker</span><span class="o">.</span><span class="n">save</span><span class="p">(</span><span class="s1">&#39;loss&#39;</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">asarray</span><span class="p">(</span><span class="n">loss</span><span class="p">))</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-235'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-235'>#</a>
            </div>
            <p>Get the gradients </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">994</span>                <span class="n">grads</span> <span class="o">=</span> <span class="n">grad_loss_fn</span><span class="p">(</span><span class="n">params</span><span class="p">,</span> <span class="n">data</span><span class="p">)</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-236'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-236'>#</a>
            </div>
            <p>Update parameters </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">996</span>                <span class="n">params</span> <span class="o">=</span> <span class="n">optimizer</span><span class="o">.</span><span class="n">step</span><span class="p">(</span><span class="n">params</span><span class="p">,</span> <span class="n">grads</span><span class="p">)</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-237'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-237'>#</a>
            </div>
            <p> </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">999</span>            <span class="n">tracker</span><span class="o">.</span><span class="n">new_line</span><span class="p">()</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-238'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-238'>#</a>
            </div>
            <p>Log a sample after each epoch </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">1001</span>            <span class="n">prompt</span> <span class="o">=</span> <span class="p">[</span><span class="n">dataset</span><span class="o">.</span><span class="n">stoi</span><span class="p">[</span><span class="n">c</span><span class="p">]</span> <span class="k">for</span> <span class="n">c</span> <span class="ow">in</span> <span class="s1">&#39;It &#39;</span><span class="p">]</span>
<span class="lineno">1002</span>            <span class="n">sampled</span> <span class="o">=</span> <span class="n">pure_sample_fn</span><span class="p">(</span><span class="n">params</span><span class="p">,</span> <span class="n">jnp</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">prompt</span><span class="p">))[</span><span class="nb">len</span><span class="p">(</span><span class="n">prompt</span><span class="p">):]</span>
<span class="lineno">1003</span>            <span class="n">sampled</span> <span class="o">=</span> <span class="s1">&#39;&#39;</span><span class="o">.</span><span class="n">join</span><span class="p">([</span><span class="n">dataset</span><span class="o">.</span><span class="n">itos</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="n">sampled</span><span class="p">])</span>
<span class="lineno">1004</span>            <span class="n">sampled</span> <span class="o">=</span> <span class="n">sampled</span><span class="o">.</span><span class="n">replace</span><span class="p">(</span><span class="s1">&#39;</span><span class="se">\n</span><span class="s1">&#39;</span><span class="p">,</span> <span class="s1">&#39;</span><span class="se">\\</span><span class="s1">n&#39;</span><span class="p">)</span>
<span class="lineno">1005</span>            <span class="n">logger</span><span class="o">.</span><span class="n">log</span><span class="p">((</span><span class="s1">&#39;It &#39;</span><span class="p">,</span> <span class="n">Text</span><span class="o">.</span><span class="n">meta</span><span class="p">),</span> <span class="p">(</span><span class="n">sampled</span><span class="p">,</span> <span class="n">Text</span><span class="o">.</span><span class="n">value</span><span class="p">))</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-239'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-239'>#</a>
            </div>
            <p> </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">1009</span><span class="k">if</span> <span class="vm">__name__</span> <span class="o">==</span> <span class="s1">&#39;__main__&#39;</span><span class="p">:</span>
<span class="lineno">1010</span>    <span class="n">main</span><span class="p">()</span></pre></div>
        </div>
    </div>
    <div class='footer'>
        <a href="https://labml.ai">labml.ai</a>
    </div>
</div>
<script src="../../interactive.js?v=1"></script>
<script>
    /**
     * Attach click-to-enlarge behavior to every <img> that is a direct
     * child of a <p> element on the page.
     */
    function handleImages() {
        var paragraphImages = document.querySelectorAll('p>img')
        // NodeList is array-like; borrow Array's forEach to visit it in document order.
        Array.prototype.forEach.call(paragraphImages, function (image) {
            handleImage(image)
        })
    }

    /**
     * Make a single image open in a modal overlay when clicked.
     *
     * Centers the image inside its parent element and builds one detached
     * modal for it: <div id="modal"><div><img></div><span class="close">x</span></div>.
     * Clicking the image attaches the modal to <body> showing the same
     * source. Clicking the "x" or pressing Escape detaches it again.
     *
     * @param {HTMLImageElement} img - image to enhance (selected by handleImages)
     */
    function handleImage(img) {
        img.parentElement.style.textAlign = 'center'

        var modal = document.createElement('div')
        modal.id = 'modal'

        var modalContent = document.createElement('div')
        modal.appendChild(modalContent)

        var modalImage = document.createElement('img')
        modalContent.appendChild(modalImage)

        var span = document.createElement('span')
        span.classList.add('close')
        span.textContent = 'x'
        modal.appendChild(span)

        // Detach the modal and stop listening for Escape.
        function closeModal() {
            // Guard against closing twice (e.g. Escape after a click already closed it).
            if (modal.parentNode === document.body) {
                document.body.removeChild(modal)
            }
            document.removeEventListener('keydown', onKeyDown)
        }

        // Keyboard dismissal for users who cannot click the close control.
        function onKeyDown(event) {
            if (event.key === 'Escape') {
                closeModal()
            }
        }

        img.onclick = function () {
            document.body.appendChild(modal)
            modalImage.src = img.src
            // Carry over the alt text so the enlarged copy is not announced as an unlabeled image.
            modalImage.alt = img.alt
            document.addEventListener('keydown', onKeyDown)
        }

        span.onclick = closeModal
    }

    // Runs immediately: this script sits at the end of <body>, so the
    // paragraph images in the sections above have already been parsed.
    handleImages()
</script>
</body>
</html>